Skip to content

Commit 6602817

Browse files
committed
treewalk secondary
1 parent 7fe1f6f commit 6602817

File tree

3 files changed

+115
-2
lines changed

3 files changed

+115
-2
lines changed

libgadget/treewalk.c

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,68 @@ static struct CommBuffer ev_secondary(struct CommBuffer * imports, struct ImpExp
686686
return res_imports;
687687
}
688688

689+
/* GPU variant of the secondary (imported-query) tree walk.
 *
 * Waits for the posted import receives in `imports` to complete. For every
 * completed receive it hands that rank's block of queries to
 * run_treewalk_secondary_kernel() and then posts a non-blocking send of the
 * corresponding block of results back to the originating rank.
 *
 * imports        - buffer holding the imported queries and their receive requests.
 * counts         - per-rank import counts/offsets and the communicator to use.
 * tw             - tree walk descriptor; supplies the query/result element sizes.
 * TreeParams_ptr - short-range gravity parameters, forwarded to the kernel launcher.
 *
 * Returns a CommBuffer that owns the result buffer and the outstanding send
 * requests. This routine does not wait for those sends.
 * NOTE(review): presumably the caller waits on and frees them - confirm. */
static struct CommBuffer ev_secondary_gpu(struct CommBuffer * imports, struct ImpExpCounts* counts, TreeWalk * tw, const struct gravshort_tree_params * TreeParams_ptr)
{
    struct CommBuffer results = {0};
    alloc_commbuffer(&results, counts->NTask, 1);
    results.databuf = (char *) mymalloc2("ImportResult", counts->Nimport * tw->result_type_elsize);

    /* One MPI element is one opaque result struct. */
    MPI_Datatype result_mpi_type;
    MPI_Type_contiguous(tw->result_type_elsize, MPI_BYTE, &result_mpi_type);
    MPI_Type_commit(&result_mpi_type);

    /* Scratch list of the request indices that MPI_Waitsome reports as finished. */
    int * finished = ta_malloc("completes", int, imports->nrequest_all);

    int ndone = 0;
    while(ndone < imports->nrequest_all) {
        /* Waitsome may hand back as little as one finished request, so keep looping
         * until every receive has been seen. Finished requests are cleaned up by MPI. */
        int nfinished = MPI_UNDEFINED;
        MPI_Waitsome(imports->nrequest_all, imports->rdata_all, &nfinished, finished, MPI_STATUSES_IGNORE);

        /* Only reported when every request is MPI_REQUEST_NULL; not expected in practice. */
        if(nfinished == MPI_UNDEFINED)
            break;

        int k;
        for(k = 0; k < nfinished; k++) {
            /* The position in the request array is NOT the rank: ranks with nothing
             * to exchange were skipped, so translate through rqst_task. */
            const int task = imports->rqst_task[finished[k]];
            const int64_t nimports_task = counts->Import_count[task];

            /* Queries and results for one rank start at the same element offset
             * in their respective buffers. */
            char * databufstart = imports->databuf + counts->Import_offset[task] * tw->query_type_elsize;
            char * dataresultstart = results.databuf + counts->Import_offset[task] * tw->result_type_elsize;

            /* Each completed chunk gets its own kernel launch. The launcher synchronizes
             * the device before returning, so the results are ready to be sent below. */
            run_treewalk_secondary_kernel(tw, P, TreeParams_ptr, databufstart, dataresultstart, nimports_task);

            /* Send the completed data back to the rank that asked for it. */
            const int slot = results.nrequest_all;
            results.rqst_task[slot] = task;
            MPI_Isend(dataresultstart, nimports_task, result_mpi_type, task, 101923, counts->comm, &results.rdata_all[slot]);
            results.nrequest_all = slot + 1;
        }
        ndone += nfinished;
    }

    myfree(finished);
    MPI_Type_free(&result_mpi_type);
    return results;
}
750+
689751
static struct ImpExpCounts
690752
ev_export_import_counts(TreeWalk * tw, MPI_Comm comm)
691753
{
@@ -879,12 +941,14 @@ treewalk_run(TreeWalk * tw, int * active_set, size_t size, struct gravshort_tree
879941
message(0, "Starting ev_primary (cpu) for %s with %ld particles\n", tw->ev_label, tw->WorkSetSize);
880942
ev_primary(tw); // cpu version
881943
#else
882-
if (TreeParams_ptr == NULL)
944+
if (TreeParams_ptr == NULL) {
883945
message(0, "Starting ev_primary (cpu) for %s with %ld particles\n", tw->ev_label, tw->WorkSetSize);
884946
ev_primary(tw); // cpu version still used for FoF now
885-
else
947+
}
948+
else {
886949
message(0, "Starting ev_primary (gpu) for %s with %ld particles\n", tw->ev_label, tw->WorkSetSize);
887950
ev_primary_gpu(tw, TreeParams_ptr); /* do local particles and prepare export list */
951+
}
888952
#endif
889953
message(0, "Finished ev_primary for %s with %ld particles\n", tw->ev_label, tw->WorkSetSize);
890954
}
@@ -897,7 +961,21 @@ treewalk_run(TreeWalk * tw, int * active_set, size_t size, struct gravshort_tree
897961
/* Posts recvs to get the export results (which are sent in ev_secondary).*/
898962
struct CommBuffer res_exports = {0};
899963
ev_recv_export_result(&res_exports, &counts, tw);
964+
#ifdef TREE_CPU
965+
message(0, "Starting ev_secondary (cpu) for %s with %ld particles\n", tw->ev_label);
900966
struct CommBuffer res_imports = ev_secondary(&imports, &counts, tw);
967+
#else
968+
struct CommBuffer res_imports;
969+
if (TreeParams_ptr == NULL) {
970+
message(0, "Starting ev_secondary (cpu) for %s with %ld particles\n", tw->ev_label);
971+
res_imports = ev_secondary(&imports, &counts, tw); // cpu version still used for FoF now
972+
}
973+
else {
974+
message(0, "Starting ev_secondary (gpu) for %s\n", tw->ev_label);
975+
res_imports = ev_secondary_gpu(&imports, &counts, tw, TreeParams_ptr);
976+
}
977+
#endif
978+
message(0, "Finished ev_secondary for %s\n", tw->ev_label);
901979
// report_memory_usage(tw->ev_label);
902980
free_commbuffer(&imports);
903981
tend = second();

libgadget/treewalk_kernel.cu

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,3 +514,36 @@ void run_treewalk_kernel(TreeWalk *tw, struct particle_data *particles, const st
514514
// message(0, "CUDA error: %s\n", cudaGetErrorString(err));
515515
// }
516516
}
517+
518+
/* Kernel: short-range gravity tree walk for imported (ghost) queries.
 *
 * Launch layout: 1D grid of 1D blocks, one thread per query. Threads whose
 * global index is >= nimports_task do no tree walk. No shared memory is used.
 *
 * tw               - tree walk descriptor; supplies query/result element sizes.
 * particles        - particle table forwarded to the device tree walk.
 * TreeParams_ptr   - short-range gravity parameters.
 * databufstart     - start of nimports_task packed queries.
 * dataresultstart  - start of nimports_task packed result slots (written here).
 * nimports_task    - number of queries in this chunk.
 *
 * The buffers are interpreted as TreeWalkQueryGravShort / TreeWalkResultGravShort,
 * so this kernel is only valid for the short-range gravity walk.
 * NOTE(review): every pointer argument is dereferenced on the device and must
 * therefore be device-accessible - confirm against the allocator used by the caller. */
__global__ void treewalk_secondary_kernel(TreeWalk *tw, struct particle_data *particles, const struct gravshort_tree_params * TreeParams_ptr, char* databufstart, char* dataresultstart, const int64_t nimports_task) {

    // Use a direct instance rather than an array
    LocalTreeWalk lv;
    ev_init_thread_device(tw, &lv);
    lv.mode = TREEWALK_GHOSTS;

    /* Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated in
     * 32-bit unsigned arithmetic and can wrap before being stored in 64 bits. */
    const int64_t tid = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;

    if (tid < nimports_task) {

        /* The query is read straight from the import buffer; only the result
         * slot is initialised here. */
        TreeWalkQueryGravShort * input = (TreeWalkQueryGravShort *) (databufstart + tid * tw->query_type_elsize);
        TreeWalkResultGravShort * output = (TreeWalkResultGravShort *) (dataresultstart + tid * tw->result_type_elsize);

        treewalk_init_result_device(tw, output, input);

        // Perform treewalk for particle; target = -1 as the query is not a local particle index
        lv.target = -1;
        force_treeev_shortrange_device(input, output, &lv, TreeParams_ptr, particles);

    }
}
542+
543+
/* Host launcher for treewalk_secondary_kernel.
 *
 * Launches one thread per imported query (256 threads per block, 1D grid) and
 * blocks until the kernel has finished, so that on return the results in
 * dataresultstart are complete.
 *
 * tw, particles, TreeParams_ptr, databufstart, dataresultstart are forwarded
 * unchanged to the kernel; nimports_task is the number of queries in the chunk.
 * A chunk with no queries is a no-op: a zero-sized grid is not a valid launch
 * configuration, so no kernel is launched in that case.
 *
 * Launch and execution errors are reported through message(); the function has
 * no return value, so the caller is not otherwise notified. */
void run_treewalk_secondary_kernel(TreeWalk *tw, struct particle_data *particles, const struct gravshort_tree_params * TreeParams_ptr, char* databufstart, char* dataresultstart, const int64_t nimports_task) {
    if (nimports_task <= 0)
        return;

    const int threadsPerBlock = 256;
    /* Ceil-division, kept in 64 bits until the launch so the arithmetic cannot truncate. */
    const int64_t blocks = (nimports_task + threadsPerBlock - 1) / threadsPerBlock;

    treewalk_secondary_kernel<<<(unsigned int) blocks, threadsPerBlock>>>(tw, particles, TreeParams_ptr, databufstart, dataresultstart, nimports_task);

    /* A kernel launch returns nothing: configuration errors show up via
     * cudaGetLastError(), faults during execution via the synchronize call. */
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess)
        err = cudaDeviceSynchronize();
    if (err != cudaSuccess)
        message(1, "CUDA error in treewalk_secondary_kernel: %s\n", cudaGetErrorString(err));
}

libgadget/treewalk_kernel.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@ void run_treewalk_kernel(TreeWalk *tw, struct particle_data *particles, const st
1313

1414
void run_gravshort_fill_ntab(const enum ShortRangeForceWindowType ShortRangeForceWindowType, const double Asmth);
1515

16+
void run_treewalk_secondary_kernel(TreeWalk *tw, struct particle_data *particles, const struct gravshort_tree_params * TreeParams_ptr, char* databufstart, char* dataresultstart, const int64_t nimports_task);
17+
1618
#endif // TREEWALK_KERNEL_H

0 commit comments

Comments
 (0)