Skip to content

Commit a627cd6

Browse files
committed
add stream to cudfFilterProject
1 parent a18b429 commit a627cd6

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

velox/experimental/cudf/exec/CudfFilterProject.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ RowVectorPtr CudfFilterProject::getOutput() {
258258

259259
auto cudf_input = std::dynamic_pointer_cast<CudfVector>(input_);
260260
VELOX_CHECK_NOT_NULL(cudf_input);
261+
auto stream = cudf_input->stream();
261262
auto input_table_columns = cudf_input->release()->release();
262263
// add ast unsupported precomputed columns to input_table
263264
// Works only directly on column in input table, not intermediate columns
@@ -267,13 +268,13 @@ RowVectorPtr CudfFilterProject::getOutput() {
267268
auto new_column = cudf::datetime::extract_datetime_component(
268269
input_table_columns[dependent_column_index]->view(),
269270
cudf::datetime::datetime_component::YEAR,
270-
cudf::get_default_stream(),
271+
stream,
271272
cudf::get_current_device_resource_ref());
272273
input_table_columns.emplace_back(std::move(new_column));
273274
} else if (ins_name == "length") {
274275
auto new_column = cudf::strings::count_characters(
275276
input_table_columns[dependent_column_index]->view(),
276-
cudf::get_default_stream(),
277+
stream,
277278
cudf::get_current_device_resource_ref());
278279
input_table_columns.emplace_back(std::move(new_column));
279280
} else {
@@ -289,7 +290,7 @@ RowVectorPtr CudfFilterProject::getOutput() {
289290
auto col = cudf::compute_column(
290291
cudf_table_view,
291292
tree.back(),
292-
cudf::get_default_stream(),
293+
stream,
293294
cudf::get_current_device_resource_ref());
294295
columns.emplace_back(std::move(col));
295296
}
@@ -308,6 +309,7 @@ RowVectorPtr CudfFilterProject::getOutput() {
308309
}
309310

310311
auto output_table = std::make_unique<cudf::table>(std::move(output_columns));
312+
stream.synchronize();
311313
auto const num_columns = output_table->num_columns();
312314
auto const size = output_table->num_rows();
313315
if (cudfDebugEnabled()) {
@@ -316,7 +318,7 @@ RowVectorPtr CudfFilterProject::getOutput() {
316318
}
317319

318320
auto cudf_output = std::make_shared<CudfVector>(
319-
input_->pool(), outputType_, size, std::move(output_table));
321+
input_->pool(), outputType_, size, std::move(output_table), stream);
320322
input_.reset();
321323
if (num_columns == 0 or size == 0) {
322324
return nullptr;

0 commit comments

Comments
 (0)