From 6318fa6aac30f125bf29518cc0acee16bbb21338 Mon Sep 17 00:00:00 2001 From: Kunjan patel Date: Wed, 25 Jun 2025 19:34:13 +0000 Subject: [PATCH] Fix namedsharding --- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index a3be8e13..85725c9a 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -397,7 +397,7 @@ def __call__( num_channels_latents=num_channel_latents, ) - data_sharding = NamedSharding(self.devices_array, P()) + data_sharding = NamedSharding(self.mesh, P()) if len(prompt) % jax.device_count() == 0: data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))