@@ -399,6 +399,58 @@ def execute(cls, audio) -> IO.NodeOutput:
399399
400400 separate = execute # TODO: remove
401401
402+ class JoinAudioChannels (IO .ComfyNode ):
403+ @classmethod
404+ def define_schema (cls ):
405+ return IO .Schema (
406+ node_id = "JoinAudioChannels" ,
407+ display_name = "Join Audio Channels" ,
408+ description = "Joins left and right mono audio channels into a stereo audio." ,
409+ category = "audio" ,
410+ inputs = [
411+ IO .Audio .Input ("audio_left" ),
412+ IO .Audio .Input ("audio_right" ),
413+ ],
414+ outputs = [
415+ IO .Audio .Output (display_name = "audio" ),
416+ ],
417+ )
418+
419+ @classmethod
420+ def execute (cls , audio_left , audio_right ) -> IO .NodeOutput :
421+ waveform_left = audio_left ["waveform" ]
422+ sample_rate_left = audio_left ["sample_rate" ]
423+ waveform_right = audio_right ["waveform" ]
424+ sample_rate_right = audio_right ["sample_rate" ]
425+
426+ if waveform_left .shape [1 ] != 1 or waveform_right .shape [1 ] != 1 :
427+ raise ValueError ("AudioJoin: Both input audios must be mono." )
428+
429+ # Handle different sample rates by resampling to the higher rate
430+ waveform_left , waveform_right , output_sample_rate = match_audio_sample_rates (
431+ waveform_left , sample_rate_left , waveform_right , sample_rate_right
432+ )
433+
434+ # Handle different lengths by trimming to the shorter length
435+ length_left = waveform_left .shape [- 1 ]
436+ length_right = waveform_right .shape [- 1 ]
437+
438+ if length_left != length_right :
439+ min_length = min (length_left , length_right )
440+ if length_left > min_length :
441+ logging .info (f"JoinAudioChannels: Trimming left channel from { length_left } to { min_length } samples." )
442+ waveform_left = waveform_left [..., :min_length ]
443+ if length_right > min_length :
444+ logging .info (f"JoinAudioChannels: Trimming right channel from { length_right } to { min_length } samples." )
445+ waveform_right = waveform_right [..., :min_length ]
446+
447+ # Join the channels into stereo
448+ left_channel = waveform_left [..., 0 :1 , :]
449+ right_channel = waveform_right [..., 0 :1 , :]
450+ stereo_waveform = torch .cat ([left_channel , right_channel ], dim = 1 )
451+
452+ return IO .NodeOutput ({"waveform" : stereo_waveform , "sample_rate" : output_sample_rate })
453+
402454
403455def match_audio_sample_rates (waveform_1 , sample_rate_1 , waveform_2 , sample_rate_2 ):
404456 if sample_rate_1 != sample_rate_2 :
@@ -616,6 +668,7 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
616668 RecordAudio ,
617669 TrimAudioDuration ,
618670 SplitAudioChannels ,
671+ JoinAudioChannels ,
619672 AudioConcat ,
620673 AudioMerge ,
621674 AudioAdjustVolume ,
0 commit comments