@@ -767,25 +767,25 @@ def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
767767
768768 # Cross-attention timesteps - compress these too
769769 av_ca_audio_scale_shift_timestep , _ = self .av_ca_audio_scale_shift_adaln_single (
770- timestep . max (). expand_as ( a_timestep_flat ) ,
770+ a_timestep_flat ,
771771 {"resolution" : None , "aspect_ratio" : None },
772772 batch_size = batch_size ,
773773 hidden_dtype = hidden_dtype ,
774774 )
775775 av_ca_video_scale_shift_timestep , _ = self .av_ca_video_scale_shift_adaln_single (
776- a_timestep . max (). expand_as ( timestep_flat ) ,
776+ timestep_flat ,
777777 {"resolution" : None , "aspect_ratio" : None },
778778 batch_size = batch_size ,
779779 hidden_dtype = hidden_dtype ,
780780 )
781781 av_ca_a2v_gate_noise_timestep , _ = self .av_ca_a2v_gate_adaln_single (
782- a_timestep .max ().expand_as (timestep_flat ) * av_ca_factor ,
782+ a_timestep_scaled .max ().expand_as (timestep_flat ) * av_ca_factor ,
783783 {"resolution" : None , "aspect_ratio" : None },
784784 batch_size = batch_size ,
785785 hidden_dtype = hidden_dtype ,
786786 )
787787 av_ca_v2a_gate_noise_timestep , _ = self .av_ca_v2a_gate_adaln_single (
788- timestep .max ().expand_as (a_timestep_flat ) * av_ca_factor ,
788+ timestep_scaled .max ().expand_as (a_timestep_flat ) * av_ca_factor ,
789789 {"resolution" : None , "aspect_ratio" : None },
790790 batch_size = batch_size ,
791791 hidden_dtype = hidden_dtype ,
0 commit comments