# Flux Kontext inference with USP
# adapted from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_flux.py

import functools
import logging
import os
import time
from typing import Optional

import torch

from xfuser.config.diffusers import has_valid_diffusers_version, get_minimum_diffusers_version

if not has_valid_diffusers_version("flux_kontext"):
    minimum_diffusers_version = get_minimum_diffusers_version("flux_kontext")
    raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Flux-Kontext.")

from diffusers import DiffusionPipeline, FluxKontextPipeline
from diffusers.utils import load_image

from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    get_data_parallel_world_size,
    get_data_parallel_rank,
    get_runtime_state,
    get_classifier_free_guidance_world_size,
    get_classifier_free_guidance_rank,
    get_cfg_group,
    get_sequence_parallel_world_size,
    get_sequence_parallel_rank,
    get_sp_group,
    is_dp_last_group,
    initialize_runtime_state,
    get_pipeline_parallel_world_size,
)

from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxAttnProcessor


def pad_to_sp_divisible(tensor: torch.Tensor, padding_length: int, dim: int) -> torch.Tensor:
    """Zero-pad `tensor` along `dim` so its length becomes divisible by the SP world size."""
    padding = torch.zeros(
        *tensor.shape[:dim], padding_length, *tensor.shape[dim + 1 :], dtype=tensor.dtype, device=tensor.device
    )
    return torch.cat([tensor, padding], dim=dim)

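# parallelize_transformer wraps the transformer's forward so that inputs are
# sharded across the classifier-free-guidance (CFG) and sequence-parallel (SP)
# groups before the original forward runs, and outputs are gathered afterwards.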
def parallelize_transformer(pipe: DiffusionPipeline):
    transformer = pipe.transformer
    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        *args,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        **kwargs,
    ):
        # Pad the image token sequence so it splits evenly across SP ranks.
        sp_world_size = get_sequence_parallel_world_size()
        sequence_length = hidden_states.shape[1]
        padding_length = (sp_world_size - (sequence_length % sp_world_size)) % sp_world_size
        if padding_length > 0:
            hidden_states = pad_to_sp_divisible(hidden_states, padding_length, dim=1)
            img_ids = pad_to_sp_divisible(img_ids, padding_length, dim=0)
        assert hidden_states.shape[0] % get_classifier_free_guidance_world_size() == 0, \
            f"Cannot split dim 0 of hidden_states ({hidden_states.shape[0]}) into {get_classifier_free_guidance_world_size()} parts."
        # Only split the text embeddings across SP ranks when their length divides evenly.
        get_runtime_state().split_text_embed_in_sp = (
            encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size() == 0
        )

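        # Shard the inputs: the batch dim is split across the CFG group, the
        # sequence dim across the SP group. img_ids/txt_ids carry positional
        # ids and must be split the same way as their token sequences.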
        if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]:
            timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        encoder_hidden_states = torch.chunk(encoder_hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        if get_runtime_state().split_text_embed_in_sp:
            encoder_hidden_states = torch.chunk(encoder_hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        img_ids = torch.chunk(img_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        if get_runtime_state().split_text_embed_in_sp:
            txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
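
        # Each rank now runs the unmodified forward on its local shard; the
        # xFuser attention processor handles the cross-rank attention exchange.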
        output = original_forward(
            hidden_states,
            encoder_hidden_states,
            *args,
            timestep=timestep,
            img_ids=img_ids,
            txt_ids=txt_ids,
            **kwargs,
        )

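        # Gather in the reverse order of the split: sequence dim across SP,
        # then batch dim across CFG, and finally drop the padding tokens.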
        return_dict = not isinstance(output, tuple)
        sample = output[0]
        sample = get_sp_group().all_gather(sample, dim=-2)
        sample = get_cfg_group().all_gather(sample, dim=0)
        if padding_length > 0:
            sample = sample[:, :-padding_length, :]
        if return_dict:
            return output.__class__(sample, *output[1:])
        return (sample, *output[1:])

    # Bind as a method so `self` resolves to the transformer instance.
    new_forward = new_forward.__get__(transformer)
    transformer.forward = new_forward

    # Swap in xFuser's attention processor so attention is computed across SP ranks.
    for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
        block.attn.processor = xFuserFluxAttnProcessor()


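# A typical multi-GPU launch (hypothetical script name and model path; the
# flags come from xFuserArgs):
#   torchrun --nproc_per_node=4 flux_kontext_example.py \
#       --model black-forest-labs/FLUX.1-Kontext-dev \
#       --ulysses_degree 2 --ring_degree 2 \
#       --prompt "make the sky stormy" --img_file_path ./input.png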
def main():
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    engine_config.runtime_config.dtype = torch.bfloat16
    local_rank = get_world_group().local_rank

    assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."

    if not args.img_file_path:
        raise ValueError("Please provide an input image path via --img_file_path. This may be a local path or a URL.")
    image = load_image(args.img_file_path)

    pipe = FluxKontextPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        torch_dtype=torch.bfloat16,
    )

    if args.enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
        logging.info(f"rank {local_rank} sequential CPU offload enabled")
    else:
        pipe = pipe.to(f"cuda:{local_rank}")

    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    initialize_runtime_state(pipe, engine_config)
    get_runtime_state().set_input_parameters(
        batch_size=1,
        num_inference_steps=input_config.num_inference_steps,
        max_condition_sequence_length=512,
        split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
    )
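
    # Patch the transformer only after the runtime state is initialized, so the
    # parallel process groups it relies on already exist.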
    parallelize_transformer(pipe)

    if engine_config.runtime_config.use_torch_compile:
        # Reorder the compiled graph to overlap collective communication with compute.
        torch._inductor.config.reorder_for_compute_comm_overlap = True
        pipe.transformer = torch.compile(pipe.transformer, mode="default")

    # Warmup: one denoising step to trigger torch.compile and other lazy initialization.
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        max_area=input_config.height * input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=1,
        output_type=input_config.output_type,
        guidance_scale=2.5,
        image=image,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    ).images

    # Timed run with fresh peak-memory statistics.
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        max_area=input_config.height * input_config.width,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        guidance_scale=2.5,
        image=image,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    parallel_info = f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}"
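    # Each data-parallel group produces its own batch of images;
    # is_dp_last_group() selects the rank that holds the final images
    # for its group, so each batch is saved exactly once.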
    if input_config.output_type == "pil":
        dp_group_index = get_data_parallel_rank()
        num_dp_groups = get_data_parallel_world_size()
        dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
        if is_dp_last_group():
            os.makedirs("./results", exist_ok=True)
            for i, image in enumerate(output.images):
                image_rank = dp_group_index * dp_batch_size + i
                image_name = f"flux_kontext_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}_{input_config.height}x{input_config.width}.png"
                image.save(f"./results/{image_name}")
                print(f"image {i} saved to ./results/{image_name}")

    if get_world_group().rank == get_world_group().world_size - 1:
        print(
            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
        )
    get_runtime_state().destroy_distributed_env()


if __name__ == "__main__":
    main()