Commit 9f758e9

Add Flux Kontext support (#592)
* Add Flux Kontext example
* Update README.md and version checks
* Fix missing padding
* Update README.md
* Fix typo

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent ccba9d5 commit 9f758e9

File tree

3 files changed: +212 −0 lines changed

README.md

Lines changed: 2 additions & 0 deletions
@@ -117,6 +117,7 @@ The following open-sourced DiT Models are released with xDiT in day 1.
 | [🎬 Wan2.2](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) || ✔️ ||| NA |
 | [🔵 HunyuanDiT-v1.2-Diffusers](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) | ✔️ | ✔️ | ✔️ || [Report](./docs/performance/hunyuandit.md) |
 | [🟠 Flux](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | NA | ✔️ | ✔️ || [Report](./docs/performance/flux.md) |
+| [🟠 Flux Kontext](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) || ✔️ ||| NA |
 | [🔴 PixArt-Sigma](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS) | ✔️ | ✔️ | ✔️ || [Report](./docs/performance/pixart_alpha_legacy.md) |
 | [🟢 PixArt-alpha](https://huggingface.co/PixArt-alpha/PixArt-alpha) | ✔️ | ✔️ | ✔️ || [Report](./docs/performance/pixart_alpha_legacy.md) |
 | [🟠 Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) | ✔️ | ✔️ | ✔️ || [Report](./docs/performance/sd3.md) |
@@ -236,6 +237,7 @@ Below is a list of validated diffusers version requirements. If the model is not
 | Model Name | Diffusers version |
 | --- | --- |
 | [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev) | >= 0.35.2 |
+| [Flux Kontext](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) | >= 0.35.2 |
 | [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) | >= 0.35.2 |
 | [Wan2.1](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers) | >= 0.35.2 |
 | [Wan2.2](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | >= 0.35.2 |
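
Any diffusers release at or above the listed minimum passes the check; for Flux Kontext, for example, `pip install "diffusers>=0.35.2"` (an illustrative command, not part of this diff) is sufficient.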

examples/flux_kontext_example.py

Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
# Flux Kontext inference with USP
# adapted from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_flux.py

import functools
import logging
import time
from typing import List, Optional

import torch

from xfuser.config.diffusers import has_valid_diffusers_version, get_minimum_diffusers_version

if not has_valid_diffusers_version("flux_kontext"):
    minimum_diffusers_version = get_minimum_diffusers_version("flux_kontext")
    raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Flux-Kontext.")

from diffusers import DiffusionPipeline, FluxKontextPipeline
from diffusers.utils import load_image

from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    get_data_parallel_world_size,
    get_data_parallel_rank,
    get_runtime_state,
    get_classifier_free_guidance_world_size,
    get_classifier_free_guidance_rank,
    get_cfg_group,
    get_sequence_parallel_world_size,
    get_sequence_parallel_rank,
    get_sp_group,
    is_dp_last_group,
    initialize_runtime_state,
    get_pipeline_parallel_world_size,
)
from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxAttnProcessor


def pad_to_sp_divisible(tensor: torch.Tensor, padding_length: int, dim: int) -> torch.Tensor:
    # Append zeros along `dim` so the length becomes divisible by the
    # sequence-parallel world size.
    padding = torch.zeros(
        *tensor.shape[:dim], padding_length, *tensor.shape[dim + 1 :],
        dtype=tensor.dtype, device=tensor.device,
    )
    return torch.cat([tensor, padding], dim=dim)
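
# Worked example (illustrative numbers, not part of this commit): with a
# sequence-parallel world size of 4 and a latent sequence of 1026 tokens,
#   padding_length = (4 - 1026 % 4) % 4 = 2,
# so the sequence is padded to 1028 tokens and each SP rank gets an equal
# 257-token shard; new_forward below slices the padding off again after the
# final all_gather.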

def parallelize_transformer(pipe: DiffusionPipeline):
    transformer = pipe.transformer
    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        *args,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        **kwargs,
    ):
        # Pad the image token sequence so it splits evenly across SP ranks.
        sp_world_size = get_sequence_parallel_world_size()
        sequence_length = hidden_states.shape[1]
        padding_length = (sp_world_size - (sequence_length % sp_world_size)) % sp_world_size
        if padding_length > 0:
            hidden_states = pad_to_sp_divisible(hidden_states, padding_length, dim=1)
            img_ids = pad_to_sp_divisible(img_ids, padding_length, dim=0)

        assert hidden_states.shape[0] % get_classifier_free_guidance_world_size() == 0, \
            f"Cannot split dim 0 of hidden_states ({hidden_states.shape[0]}) into {get_classifier_free_guidance_world_size()} parts."
        # Text embeddings are only sharded across SP ranks when their length
        # divides evenly; otherwise every rank keeps the full text sequence.
        if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size() != 0:
            get_runtime_state().split_text_embed_in_sp = False
        else:
            get_runtime_state().split_text_embed_in_sp = True

        # Shard dim 0 across the CFG group and the token dim across the SP group.
        if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]:
            timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        encoder_hidden_states = torch.chunk(encoder_hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        if get_runtime_state().split_text_embed_in_sp:
            encoder_hidden_states = torch.chunk(encoder_hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        img_ids = torch.chunk(img_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        if get_runtime_state().split_text_embed_in_sp:
            txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
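        # At this point each rank holds a [batch / cfg_degree, tokens / sp_degree, dim]
        # shard of hidden_states, so the wrapped forward below processes only
        # 1 / (cfg_degree * sp_degree) of the tokens per device.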

        output = original_forward(
            hidden_states,
            encoder_hidden_states,
            *args,
            timestep=timestep,
            img_ids=img_ids,
            txt_ids=txt_ids,
            **kwargs,
        )

        return_dict = not isinstance(output, tuple)
        sample = output[0]
        # Reassemble the full sample: gather token shards across the SP group,
        # then batch shards across the CFG group, and drop the padding.
        sample = get_sp_group().all_gather(sample, dim=-2)
        sample = get_cfg_group().all_gather(sample, dim=0)
        if padding_length > 0:
            sample = sample[:, :-padding_length, :]
        if return_dict:
            return output.__class__(sample, *output[1:])
        return (sample, *output[1:])

    new_forward = new_forward.__get__(transformer)
    transformer.forward = new_forward

    for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
        block.attn.processor = xFuserFluxAttnProcessor()
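
# Note: new_forward.__get__(transformer) binds the wrapper as a method of this
# transformer instance, and swapping in xFuserFluxAttnProcessor is what makes
# each attention layer sequence-parallel-aware (xDiT's USP attention), so the
# sharded sequences still attend over the full token set.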

def main():
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    engine_config.runtime_config.dtype = torch.bfloat16
    local_rank = get_world_group().local_rank

    assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."

    if not args.img_file_path:
        raise ValueError("Please provide an input image path via --img_file_path. This may be a local path or a URL.")
    image = load_image(args.img_file_path)

    pipe = FluxKontextPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        torch_dtype=torch.bfloat16,
    )

    if args.enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
        logging.info(f"rank {local_rank} sequential CPU offload enabled")
    else:
        pipe = pipe.to(f"cuda:{local_rank}")

    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    initialize_runtime_state(pipe, engine_config)
    get_runtime_state().set_input_parameters(
        batch_size=1,
        num_inference_steps=input_config.num_inference_steps,
        max_condition_sequence_length=512,
        split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
    )
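    # split_text_embed_in_sp is enabled only when there is no pipeline
    # parallelism; the assert above forces pipefusion_parallel_degree == 1,
    # so the expression evaluates to True here.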

    parallelize_transformer(pipe)

    if engine_config.runtime_config.use_torch_compile:
        torch._inductor.config.reorder_for_compute_comm_overlap = True
        pipe.transformer = torch.compile(pipe.transformer, mode="default")

        # one step to warmup the torch compiler
        output = pipe(
            height=input_config.height,
            width=input_config.width,
            max_area=input_config.height * input_config.width,
            prompt=input_config.prompt,
            num_inference_steps=1,
            output_type=input_config.output_type,
            guidance_scale=2.5,
            image=image,
            generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
        ).images

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        max_area=input_config.height * input_config.width,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        guidance_scale=2.5,
        image=image,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    parallel_info = (
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
    )
    if input_config.output_type == "pil":
        dp_group_index = get_data_parallel_rank()
        num_dp_groups = get_data_parallel_world_size()
        dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
        if is_dp_last_group():
            for i, image in enumerate(output.images):
                image_rank = dp_group_index * dp_batch_size + i
                image_name = f"flux_kontext_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}_{input_config.height}x{input_config.width}.png"
                image.save(f"./results/{image_name}")
                print(f"image {i} saved to ./results/{image_name}")

    if get_world_group().rank == get_world_group().world_size - 1:
        print(
            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
        )
    get_runtime_state().destroy_distributed_env()


if __name__ == "__main__":
    main()
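
A typical multi-GPU launch of this example might look like the following (the flags are illustrative; xFuserArgs defines the authoritative CLI, so check --help):

    torchrun --nproc_per_node=2 examples/flux_kontext_example.py \
        --model black-forest-labs/FLUX.1-Kontext-dev \
        --ulysses_degree 2 --num_inference_steps 28 \
        --prompt "make the sky a sunset" \
        --img_file_path ./input.png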

xfuser/config/diffusers.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 DEFAULT_MINIMUM_DIFFUSERS_VERSION = "0.33.0"
 MINIMUM_DIFFUSERS_VERSIONS = {
     "flux": "0.35.2",
+    "flux_kontext": "0.35.2",
     "hunyuanvideo": "0.35.2",
     "wan": "0.35.2",
 }
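
For context, a minimal sketch of how these tables could back the version gate used in the example (illustrative only; this commit does not show the helper bodies, which live elsewhere in xfuser/config/diffusers.py):

    from importlib.metadata import version
    from packaging.version import Version

    def get_minimum_diffusers_version(model: str) -> str:
        # Fall back to the default when a model has no pinned minimum.
        return MINIMUM_DIFFUSERS_VERSIONS.get(model, DEFAULT_MINIMUM_DIFFUSERS_VERSION)

    def has_valid_diffusers_version(model: str) -> bool:
        # Compare the installed diffusers release against the required minimum.
        return Version(version("diffusers")) >= Version(get_minimum_diffusers_version(model))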
