@@ -151,14 +151,22 @@ void CudfHashJoinBuild::noMoreInput() {
151151 };
152152
153153 auto cudf_tables = std::vector<std::unique_ptr<cudf::table>>(inputs_.size ());
154+ auto input_streams = std::vector<rmm::cuda_stream_view>(inputs_.size ());
154155 for (int i = 0 ; i < inputs_.size (); i++) {
155156 VELOX_CHECK_NOT_NULL (inputs_[i]);
157+ input_streams[i] = inputs_[i]->stream ();
156158 cudf_tables[i] = inputs_[i]->release ();
157159 }
158- auto tbl = concatenateTables (std::move (cudf_tables));
160+ auto stream = cudfGlobalStreamPool ().get_stream ();
161+ cudf::detail::join_streams (input_streams, stream);
162+ auto tbl = concatenateTables (std::move (cudf_tables), stream);
163+
164+ // Release input data after synchronizing
165+ stream.synchronize ();
166+ input_streams.clear ();
167+ cudf_tables.clear ();
159168
160169 // Release input data
161- cudf::get_default_stream ().synchronize ();
162170 inputs_.clear ();
163171
164172 VELOX_CHECK_NOT_NULL (tbl);
@@ -246,6 +254,7 @@ RowVectorPtr CudfHashJoinProbe::getOutput() {
246254 }
247255 auto cudf_input = std::dynamic_pointer_cast<CudfVector>(input_);
248256 VELOX_CHECK_NOT_NULL (cudf_input);
257+ auto stream = cudf_input->stream ();
249258 auto tbl = cudf_input->release ();
250259 if (cudfDebugEnabled ()) {
251260 std::cout << " Probe table number of columns: " << tbl->num_columns ()
@@ -307,8 +316,8 @@ RowVectorPtr CudfHashJoinProbe::getOutput() {
307316 hb.get (),
308317 hashObject_.has_value ());
309318 }
310- auto const [left_join_indices, right_join_indices] =
311- hb-> inner_join ( tbl->view ().select (probe_key_indices));
319+ auto const [left_join_indices, right_join_indices] = hb-> inner_join (
320+ tbl->view ().select (probe_key_indices), std:: nullopt , stream );
312321 auto left_indices_span =
313322 cudf::device_span<cudf::size_type const >{*left_join_indices};
314323 auto right_indices_span =
@@ -361,8 +370,10 @@ RowVectorPtr CudfHashJoinProbe::getOutput() {
361370 auto left_indices_col = cudf::column_view{left_indices_span};
362371 auto right_indices_col = cudf::column_view{right_indices_span};
363372 auto constexpr oob_policy = cudf::out_of_bounds_policy::DONT_CHECK;
364- auto left_result = cudf::gather (left_input, left_indices_col, oob_policy);
365- auto right_result = cudf::gather (right_input, right_indices_col, oob_policy);
373+ auto left_result =
374+ cudf::gather (left_input, left_indices_col, oob_policy, stream);
375+ auto right_result =
376+ cudf::gather (right_input, right_indices_col, oob_policy, stream);
366377
367378 if (cudfDebugEnabled ()) {
368379 std::cout << " Left result number of columns: " << left_result->num_columns ()
@@ -391,7 +402,7 @@ RowVectorPtr CudfHashJoinProbe::getOutput() {
391402 return nullptr ;
392403 }
393404 return std::make_shared<CudfVector>(
394- pool (), outputType, size, std::move (cudf_output));
405+ pool (), outputType, size, std::move (cudf_output), stream );
395406}
396407
397408exec::BlockingReason CudfHashJoinProbe::isBlocked (ContinueFuture* future) {
0 commit comments