From 7995f7acef1c0bc4ea1e3ac61ace5bf23c010c28 Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Wed, 26 Nov 2025 12:36:16 +0000 Subject: [PATCH 1/5] Add Flux kontext example --- examples/flux_kontext_example.py | 209 +++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 examples/flux_kontext_example.py diff --git a/examples/flux_kontext_example.py b/examples/flux_kontext_example.py new file mode 100644 index 00000000..bae7c3fc --- /dev/null +++ b/examples/flux_kontext_example.py @@ -0,0 +1,209 @@ +# Flux inference with USP +# from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_flux.py + +import functools + +import logging +import time +import torch +from xfuser.config.diffusers import has_valid_diffusers_version, get_minimum_diffusers_version +from typing import List, Optional + +if not has_valid_diffusers_version("flux_kontext"): + minimum_diffusers_version = get_minimum_diffusers_version("flux_kontext") + raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Flux-Kontext.") + +from diffusers import DiffusionPipeline, FluxKontextPipeline +from diffusers.utils import load_image + +from xfuser import xFuserArgs +from xfuser.config import FlexibleArgumentParser +from xfuser.core.distributed import ( + get_world_group, + get_data_parallel_world_size, + get_data_parallel_rank, + get_runtime_state, + get_classifier_free_guidance_world_size, + get_classifier_free_guidance_rank, + get_cfg_group, + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_sp_group, + is_dp_last_group, + initialize_runtime_state, + get_pipeline_parallel_world_size, +) + +from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxAttnProcessor + +def pad_to_sp_divisable(tensor: torch.Tensor, padding_length: int, dim: int) -> torch.Tensor: + + padding = torch.zeros( + *tensor.shape[:dim], padding_length, *tensor.shape[dim + 1 :], dtype=tensor.dtype, device=tensor.device + ) + tensor = torch.cat([tensor, padding], dim=dim) + return tensor + +def parallelize_transformer(pipe: DiffusionPipeline): + transformer = pipe.transformer + original_forward = transformer.forward + + @functools.wraps(transformer.__class__.forward) + def new_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + *args, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + **kwargs, + ): + + sp_world_size = get_sequence_parallel_world_size() + sequence_length = hidden_states.shape[1] + padding_length = sp_world_size - (sequence_length % sp_world_size) % sp_world_size + if padding_length > 0: + hidden_states = pad_to_sp_divisable(hidden_states, padding_length, dim=1) + img_ids = pad_to_sp_divisable(img_ids, padding_length, dim=0) + assert hidden_states.shape[0] % get_classifier_free_guidance_world_size() == 0, \ + f"Cannot split dim 0 of hidden_states ({hidden_states.shape[0]}) into {get_classifier_free_guidance_world_size()} parts." 
+        if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size() != 0:
+            get_runtime_state().split_text_embed_in_sp = False
+        else:
+            get_runtime_state().split_text_embed_in_sp = True
+
+        if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]:
+            timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
+        hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
+        hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+        encoder_hidden_states = torch.chunk(encoder_hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
+        if get_runtime_state().split_text_embed_in_sp:
+            encoder_hidden_states = torch.chunk(encoder_hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+        img_ids = torch.chunk(img_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+        if get_runtime_state().split_text_embed_in_sp:
+            txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+
+
+        output = original_forward(
+            hidden_states,
+            encoder_hidden_states,
+            *args,
+            timestep=timestep,
+            img_ids=img_ids,
+            txt_ids=txt_ids,
+            **kwargs,
+        )
+
+        return_dict = not isinstance(output, tuple)
+        sample = output[0]
+        sample = get_sp_group().all_gather(sample, dim=-2)
+        sample = get_cfg_group().all_gather(sample, dim=0)
+        if padding_length > 0:
+            sample = sample[:, :-padding_length, :]
+        if return_dict:
+            return output.__class__(sample, *output[1:])
+        return (sample, *output[1:])
+
+    new_forward = new_forward.__get__(transformer)
+    transformer.forward = new_forward
+
+    for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
+        block.attn.processor = xFuserFluxAttnProcessor()
+
+
+def main():
+    parser = FlexibleArgumentParser(description="xFuser Arguments")
+    args = xFuserArgs.add_cli_args(parser).parse_args()
+    engine_args = xFuserArgs.from_cli_args(args)
+    engine_config, input_config = engine_args.create_config()
+    engine_config.runtime_config.dtype = torch.bfloat16
+    local_rank = get_world_group().local_rank
+
+    assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."
+
+    if not args.img_file_path:
+        raise ValueError("Please provide an input image path via --img_file_path. This may be a local path or a URL.")
+    image = load_image(args.img_file_path)
+
+    pipe = FluxKontextPipeline.from_pretrained(
+        pretrained_model_name_or_path=engine_config.model_config.model,
+        torch_dtype=torch.bfloat16,
+    )
+
+    if args.enable_sequential_cpu_offload:
+        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
+        logging.info(f"rank {local_rank} sequential CPU offload enabled")
+    else:
+        pipe = pipe.to(f"cuda:{local_rank}")
+
+    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
+    initialize_runtime_state(pipe, engine_config)
+    get_runtime_state().set_input_parameters(
+        batch_size=1,
+        num_inference_steps=input_config.num_inference_steps,
+        max_condition_sequence_length=512,
+        split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
+    )
+
+    parallelize_transformer(pipe)
+
+    if engine_config.runtime_config.use_torch_compile:
+        torch._inductor.config.reorder_for_compute_comm_overlap = True
+        pipe.transformer = torch.compile(pipe.transformer, mode="default")
+
+    # one step to warm up the torch compiler
+    output = pipe(
+        height=input_config.height,
+        width=input_config.width,
+        max_area=input_config.height * input_config.width,
+        prompt=input_config.prompt,
+        num_inference_steps=1,
+        output_type=input_config.output_type,
+        guidance_scale=2.5,
+        image=image,
+        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
+    ).images
+
+    torch.cuda.reset_peak_memory_stats()
+    start_time = time.time()
+
+    output = pipe(
+        height=input_config.height,
+        width=input_config.width,
+        prompt=input_config.prompt,
+        max_area=input_config.height * input_config.width,
+        num_inference_steps=input_config.num_inference_steps,
+        output_type=input_config.output_type,
+        guidance_scale=2.5,
+        image=image,
+        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
+    )
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
+    parallel_info = (
+        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
+    )
+    if input_config.output_type == "pil":
+        dp_group_index = get_data_parallel_rank()
+        num_dp_groups = get_data_parallel_world_size()
+        dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
+        if is_dp_last_group():
+            for i, image in enumerate(output.images):
+                image_rank = dp_group_index * dp_batch_size + i
+                image_name = f"flux_kontext_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}_{input_config.height}x{input_config.width}.png"
+                image.save(f"./results/{image_name}")
+                print(f"image {i} saved to ./results/{image_name}")
+
+    if get_world_group().rank == get_world_group().world_size - 1:
+        print(
+            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
+        )
+    get_runtime_state().destroy_distributed_env()
+
+
+if __name__ == "__main__":
+    main()

From 252c7262b027691e7c23383f18182984a40a6b02 Mon Sep 17 00:00:00 2001
From: Aleksi Vesanto
Date: Wed, 26 Nov 2025 12:37:32 +0000
Subject: [PATCH 2/5] Update README.md and version checks

---
 README.md                  | 2 ++
 xfuser/config/diffusers.py | 1 +
 2 files changed, 3 insertions(+)

diff --git a/README.md b/README.md
index 0f68a837..7f43d98a 100644
--- a/README.md
+++ b/README.md
@@ -117,6 +117,7 @@ The following open-sourced DiT Models are released with xDiT in day 1.
| [🎬 Wan2.2](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | ❎ | ✔️ | ❎ | ❎ | NA | | [🔵 HunyuanDiT-v1.2-Diffusers](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) | ✔️ | ✔️ | ✔️ | ❎ | [Report](./docs/performance/hunyuandit.md) | | [🟠 Flux](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | NA | ✔️ | ✔️ | ❎ | [Report](./docs/performance/flux.md) | +| [🟠 Flux Kontext](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) | ❎ | ✔️ | ❎ ️ | ❎ | NA | | [🔴 PixArt-Sigma](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS) | ✔️ | ✔️ | ✔️ | ❎ | [Report](./docs/performance/pixart_alpha_legacy.md) | | [🟢 PixArt-alpha](https://huggingface.co/PixArt-alpha/PixArt-alpha) | ✔️ | ✔️ | ✔️ | ❎ | [Report](./docs/performance/pixart_alpha_legacy.md) | | [🟠 Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) | ✔️ | ✔️ | ✔️ | ❎ | [Report](./docs/performance/sd3.md) | @@ -236,6 +237,7 @@ Below is a list of validated diffusers version requirements. If the model is not | Model Name | Diffusers version | | --- | --- | | [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev) | >= 0.35.2 | +| [Flux Kontext](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) | >= 0.35.2 | | [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) | >= 0.35.2 | | [Wan2.1](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers) | >= 0.35.2 | | [Wan2.2](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | >= 0.35.2 | diff --git a/xfuser/config/diffusers.py b/xfuser/config/diffusers.py index b709653a..b1cd0d28 100644 --- a/xfuser/config/diffusers.py +++ b/xfuser/config/diffusers.py @@ -4,6 +4,7 @@ DEFAULT_MINIMUM_DIFFUSERS_VERSION = "0.33.0" MINIMUM_DIFFUSERS_VERSIONS = { "flux": "0.35.2", + "flux_kontext": "0.35.2", "hunyuanvideo": "0.35.2", "wan": "0.35.2", } From 79a84c62dc0cc15cc36c8e404d83e1b003fa6724 Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Wed, 26 Nov 2025 12:48:03 +0000 Subject: [PATCH 3/5] Fix missing padding fix --- examples/flux_kontext_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/flux_kontext_example.py b/examples/flux_kontext_example.py index bae7c3fc..3e4743b9 100644 --- a/examples/flux_kontext_example.py +++ b/examples/flux_kontext_example.py @@ -62,7 +62,7 @@ def new_forward( sp_world_size = get_sequence_parallel_world_size() sequence_length = hidden_states.shape[1] - padding_length = sp_world_size - (sequence_length % sp_world_size) % sp_world_size + padding_length = (sp_world_size - (sequence_length % sp_world_size)) % sp_world_size if padding_length > 0: hidden_states = pad_to_sp_divisable(hidden_states, padding_length, dim=1) img_ids = pad_to_sp_divisable(img_ids, padding_length, dim=0) From cf226172dce35b624dc74a0777abff5b6c02414e Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Wed, 26 Nov 2025 14:48:38 +0200 Subject: [PATCH 4/5] Update README.md Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7f43d98a..e6d4b226 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ The following open-sourced DiT Models are released with xDiT in day 1. 
| [🎬 Wan2.2](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | ❎ | ✔️ | ❎ | ❎ | NA | | [🔵 HunyuanDiT-v1.2-Diffusers](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) | ✔️ | ✔️ | ✔️ | ❎ | [Report](./docs/performance/hunyuandit.md) | | [🟠 Flux](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | NA | ✔️ | ✔️ | ❎ | [Report](./docs/performance/flux.md) | -| [🟠 Flux Kontext](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) | ❎ | ✔️ | ❎ ️ | ❎ | NA | +| [🟠 Flux Kontext](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) | ❎ | ✔️ | ❎ | ❎ | NA | | [🔴 PixArt-Sigma](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS) | ✔️ | ✔️ | ✔️ | ❎ | [Report](./docs/performance/pixart_alpha_legacy.md) | | [🟢 PixArt-alpha](https://huggingface.co/PixArt-alpha/PixArt-alpha) | ✔️ | ✔️ | ✔️ | ❎ | [Report](./docs/performance/pixart_alpha_legacy.md) | | [🟠 Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) | ✔️ | ✔️ | ✔️ | ❎ | [Report](./docs/performance/sd3.md) | From a6d7cd0434f1e3c9a2a6cfb7f764d884f739d53a Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Wed, 26 Nov 2025 13:03:11 +0000 Subject: [PATCH 5/5] Fix typo --- examples/flux_kontext_example.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/flux_kontext_example.py b/examples/flux_kontext_example.py index 3e4743b9..43eabe6f 100644 --- a/examples/flux_kontext_example.py +++ b/examples/flux_kontext_example.py @@ -36,7 +36,7 @@ from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxAttnProcessor -def pad_to_sp_divisable(tensor: torch.Tensor, padding_length: int, dim: int) -> torch.Tensor: +def pad_to_sp_divisible(tensor: torch.Tensor, padding_length: int, dim: int) -> torch.Tensor: padding = torch.zeros( *tensor.shape[:dim], padding_length, *tensor.shape[dim + 1 :], dtype=tensor.dtype, device=tensor.device @@ -64,8 +64,8 @@ def new_forward( sequence_length = hidden_states.shape[1] padding_length = (sp_world_size - (sequence_length % sp_world_size)) % sp_world_size if padding_length > 0: - hidden_states = pad_to_sp_divisable(hidden_states, padding_length, dim=1) - img_ids = pad_to_sp_divisable(img_ids, padding_length, dim=0) + hidden_states = pad_to_sp_divisible(hidden_states, padding_length, dim=1) + img_ids = pad_to_sp_divisible(img_ids, padding_length, dim=0) assert hidden_states.shape[0] % get_classifier_free_guidance_world_size() == 0, \ f"Cannot split dim 0 of hidden_states ({hidden_states.shape[0]}) into {get_classifier_free_guidance_world_size()} parts." if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size() != 0:
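
A note on the arithmetic behind PATCH 3: in Python, `%` binds tighter than `-`, so the original expression `sp_world_size - (sequence_length % sp_world_size) % sp_world_size` parsed as `sp_world_size - (sequence_length % sp_world_size)` and padded by a full `sp_world_size` tokens whenever the sequence length was already divisible. The corrected form wraps the subtraction before the final modulo, yielding zero padding in that case. The sketch below replays the pad/chunk/gather/strip round trip from `new_forward` in a single process; plain `torch.cat` stands in for `get_sp_group().all_gather`, and the tensor shapes are illustrative only.

    import torch

    sp_world_size = 4                         # stand-in for get_sequence_parallel_world_size()
    hidden_states = torch.randn(1, 4098, 64)  # sequence length 4098 is not divisible by 4

    sequence_length = hidden_states.shape[1]
    padding_length = (sp_world_size - (sequence_length % sp_world_size)) % sp_world_size
    # the pre-PATCH-3 precedence also gives 2 here, but returns 4 instead of 0
    # for an already-divisible length such as 4096
    assert padding_length == 2

    # zero-pad along the sequence dimension, as pad_to_sp_divisible does
    padding = torch.zeros(1, padding_length, 64, dtype=hidden_states.dtype)
    padded = torch.cat([hidden_states, padding], dim=1)   # [1, 4100, 64]

    # each sequence-parallel rank would process one of these chunks
    chunks = torch.chunk(padded, sp_world_size, dim=-2)   # 4 chunks of [1, 1025, 64]

    # single-process stand-in for get_sp_group().all_gather(sample, dim=-2)
    gathered = torch.cat(chunks, dim=-2)

    # strip the zero padding after the gather, as new_forward does
    restored = gathered[:, :-padding_length, :] if padding_length > 0 else gathered
    assert torch.equal(restored, hidden_states)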
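
For reference, a plausible launch command for the new example, assuming the standard xFuserArgs flags that main() reads above (--model, --height, --width, --prompt, --num_inference_steps, --ulysses_degree) plus the --img_file_path flag this patch introduces; the model path and parallel degree here are illustrative:

    torchrun --nproc_per_node=2 examples/flux_kontext_example.py \
        --model black-forest-labs/FLUX.1-Kontext-dev \
        --ulysses_degree 2 \
        --height 1024 --width 1024 \
        --num_inference_steps 28 \
        --prompt "make the scene snowy" \
        --img_file_path ./input.png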