diff --git a/ggml.c b/ggml.c
index 37b16b7a9ce7f..0aa14ef9dca2f 100644
--- a/ggml.c
+++ b/ggml.c
@@ -112,8 +112,6 @@ typedef void * thread_ret_t;
 #endif
 
-typedef pthread_t ggml_thread_t;
-
 #ifdef GGML_USE_CPU_HBM
 #include <hbwmalloc.h>
 #endif
 
@@ -1722,57 +1720,57 @@ static inline void __lsx_f16x4_store(ggml_fp16_t *x, __m128 y) {
 #endif
 
 //
-// ggml context
+// synchronization primitives
 //
 
-struct ggml_context {
-    size_t mem_size;
-    void* mem_buffer;
-    bool mem_buffer_owned;
-    bool no_alloc;
-    bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
-
-    int n_objects;
-
-    struct ggml_object* objects_begin;
-    struct ggml_object* objects_end;
-
-    struct ggml_scratch scratch;
-    struct ggml_scratch scratch_save;
+struct ggml_once {
+    atomic_int state;
 };
 
-struct ggml_context_container {
-    bool used;
-
-    struct ggml_context context;
+struct ggml_barrier {
+    atomic_uint phase;
+    atomic_int count;
 };
 
-struct ggml_compute_state_shared {
-    const struct ggml_cgraph* cgraph;
-    const struct ggml_cplan* cplan;
-
-    int64_t perf_node_start_cycles;
-    int64_t perf_node_start_time_us;
-
-    const int n_threads;
-
-    // synchronization primitives
-    atomic_int n_active;  // num active threads
-    atomic_int node_n;    // active graph node
-    atomic_int node_task; // active graph node task phase
-
-    ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
-    void* abort_callback_data;
+void ggml_once(struct ggml_once * once, void init(void)) {
+    int old = atomic_load_explicit(&once->state, memory_order_acquire);
+    if (!old && atomic_compare_exchange_strong_explicit(&once->state, &old, 1,
+                                                        memory_order_acquire,
+                                                        memory_order_relaxed)) {
+        init();
+        atomic_store_explicit(&once->state, 2, memory_order_release);
+        return;
+    }
+    while (old == 1) {
+        old = atomic_load_explicit(&once->state, memory_order_acquire);
+    }
+}
 
-    atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
-};
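
The new ggml_once above is a C11-atomics stand-in for pthread_once: state 0 means the initializer has not run, 1 means some thread is running it, and 2 means it has finished; the release store pairs with the acquire loads so every caller observes the initializer's side effects. A minimal usage sketch, assuming the struct definition is in scope (g_once and init_tables are illustrative, not part of the patch):

    static struct ggml_once g_once;       // zero-initialized: state 0 (not yet run)

    static void init_tables(void) {       // hypothetical one-time initializer
        // populate lookup tables here
    }

    static void any_worker_thread(void) {
        // init_tables() runs exactly once no matter how many threads race here;
        // losers of the CAS spin until state reaches 2
        ggml_once(&g_once, init_tables);
        // ... the tables are now safely visible to this thread ...
    }
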
+int ggml_delay(int backoff) {
+    if (backoff < 12) {
+        volatile int i;
+        for (i = 0; i != 1 << backoff; i++) {
+        }
+        backoff++;
+    } else {
+        sched_yield();
+    }
+    return backoff;
+}
 
-struct ggml_compute_state {
-    ggml_thread_t thrd;
-    int ith;
-    struct ggml_compute_state_shared* shared;
-    enum ggml_status ec;
-};
+// blocks until all nth threads have called this; the barrier resets itself for reuse
+void ggml_syncthreads(struct ggml_barrier * b, int nth) {
+    unsigned phase = atomic_load_explicit(&b->phase, memory_order_relaxed);
+    if (atomic_fetch_add_explicit(&b->count, 1, memory_order_acq_rel) + 1 == nth) {
+        atomic_store_explicit(&b->count, 0, memory_order_relaxed);
+        atomic_store_explicit(&b->phase, phase + 1, memory_order_release);
+    } else {
+        int backoff = 0;
+        while (atomic_load_explicit(&b->phase, memory_order_acquire) == phase) {
+            backoff = ggml_delay(backoff);
+        }
+    }
+}
 
 //
 // fundamental operations
 //
@@ -2838,7 +2836,6 @@ static void ggml_setup_op_has_task_pass(void) {
         bool * p = GGML_OP_HAS_INIT;
 
         p[GGML_OP_ACC        ] = true;
-        p[GGML_OP_MUL_MAT    ] = true;
         p[GGML_OP_MUL_MAT_ID ] = true;
         p[GGML_OP_OUT_PROD   ] = true;
         p[GGML_OP_SET        ] = true;
@@ -2859,6 +2856,32 @@ static void ggml_setup_op_has_task_pass(void) {
     }
 }
 
+//
+// ggml context
+//
+
+struct ggml_context {
+    size_t mem_size;
+    void * mem_buffer;
+    bool mem_buffer_owned;
+    bool no_alloc;
+    bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
+
+    int n_objects;
+
+    struct ggml_object * objects_begin;
+    struct ggml_object * objects_end;
+
+    struct ggml_scratch scratch;
+    struct ggml_scratch scratch_save;
+};
+
+struct ggml_context_container {
+    bool used;
+
+    struct ggml_context context;
+};
+
 //
 // NUMA support
 //
@@ -12302,101 +12325,9 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) {
 }
 #endif
 
-static void ggml_compute_forward_mul_mat_one_chunk(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst,
-    const int64_t num_rows_per_vec_dot,
-    const int64_t ir0_start,
-    const int64_t ir0_end,
-    const int64_t ir1_start,
-    const int64_t ir1_end) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const enum ggml_type type = src0->type;
-
-    const bool src1_cont = ggml_is_contiguous(src1);
-
-    ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
-    enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
-
-    // broadcast factors
-    const int64_t r2 = ne12 / ne02;
-    const int64_t r3 = ne13 / ne03;
-
-    //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
-
-    // threads with no work simply yield (not sure if it helps)
-    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
-        return;
-    }
-
-    const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-    const size_t row_size = ggml_row_size(vec_dot_type, ne10);
-
-    assert(ne12 % ne02 == 0);
-    assert(ne13 % ne03 == 0);
-
-    // block-tiling attempt
-    const int64_t blck_0 = 16;
-    const int64_t blck_1 = 16;
-
-    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
-
-    // attempt to reduce false-sharing (does not seem to make a difference)
-    // 16 * 2, accounting for mmla kernels
-    float tmp[32];
-
-    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
-        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
-            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
-                const int64_t i13 = (ir1 / (ne12 * ne1));
-                const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
-                const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
-
-                // broadcast src0 into src1
-                const int64_t i03 = i13 / r3;
-                const int64_t i02 = i12 / r2;
-
-                const int64_t i1 = i11;
-                const int64_t i2 = i12;
-                const int64_t i3 = i13;
-
-                const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
-
-                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
-                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
-                //       the original src1 data pointer, so we should index using the indices directly
-                // TODO: this is a bit of a hack, we should probably have a better way to handle this
-                const char * src1_col = (const char*)wdata +
-                    (src1_cont || src1->type != vec_dot_type
-                        ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
-                        : (i11 * nb11 + i12 * nb12 + i13 * nb13));
-                float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
-
-                //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
-                //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
-                //}
-
-                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
-                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
-                }
-
-                for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
-                    memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
-                }
-            }
-        }
-    }
-}
-
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
-              struct ggml_tensor * dst,
-              struct ggml_compute_state * state) {
+              struct ggml_tensor * dst) {
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
@@ -12411,6 +12342,9 @@ static void ggml_compute_forward_mul_mat(
 
     const enum ggml_type type = src0->type;
 
+    const bool src1_cont = ggml_is_contiguous(src1);
+
+    ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
     int64_t const vec_dot_num_rows = type_traits[type].nrows;
@@ -12431,17 +12365,15 @@ static void ggml_compute_forward_mul_mat(
     GGML_ASSERT(nb2 <= nb3);
 
     // broadcast factors
-    const int64_t r2 = ne12 / ne02;
-    const int64_t r3 = ne13 / ne03;
-    UNUSED(r2);
-    UNUSED(r3);
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
 #if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_TYPE_COMPUTE) {
+        if (params->ith == 0) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
         return;
@@ -12454,31 +12386,25 @@ static void ggml_compute_forward_mul_mat(
         const size_t desired_wsize = ne13*ne12*ne_plane*sizeof(float);
         UNUSED(desired_wsize);
 
-        if (params->type == GGML_TASK_TYPE_INIT) {
-            if (type != GGML_TYPE_F32) {
-                assert(params->wsize >= desired_wsize);
-                // parallelize by src0 rows
-                for (int64_t i13 = 0; i13 < ne13; i13++) {
-                    for (int64_t i12 = 0; i12 < ne12; i12++) {
-                        // broadcast src0 into src1 across 2nd,3rd dimension
-                        const int64_t i03 = i13/r3;
-                        const int64_t i02 = i12/r2;
-
-                        const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
-                        float * const wdata = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
-                        ggml_to_float_t const to_float = type_traits[type].to_float;
-
-                        for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
-                            to_float((const char *) x + i01*nb01, wdata + i01*ne00, ne00);
-                        }
+        if (type != GGML_TYPE_F32) {
+            assert(params->wsize >= desired_wsize);
+            // parallelize by src0 rows
+            for (int64_t i13 = 0; i13 < ne13; i13++) {
+                for (int64_t i12 = 0; i12 < ne12; i12++) {
+                    // broadcast src0 into src1 across 2nd,3rd dimension
+                    const int64_t i03 = i13/r3;
+                    const int64_t i02 = i12/r2;
+
+                    const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+                    float * const wdata = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
+                    ggml_to_float_t const to_float = type_traits[type].to_float;
+
+                    for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                        to_float((const char *) x + i01*nb01, wdata + i01*ne00, ne00);
                     }
                 }
             }
-            return;
-        }
-
-        if (params->type == GGML_TASK_TYPE_FINALIZE) {
-            return;
+            ggml_syncthreads(params->barrier, params->nth);
         }
 
         // perform sgemm, parallelization controlled by blas lib
@@ -12516,8 +12442,6 @@ static void ggml_compute_forward_mul_mat(
 #endif
 
 #if GGML_USE_LLAMAFILE
-    const bool src1_cont = ggml_is_contiguous(src1);
-
     if (src1_cont) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
@@ -12539,41 +12463,36 @@ static void ggml_compute_forward_mul_mat(
 UseGgmlGemm1:;
 #endif
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
-        // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
-        atomic_store(&state->shared->current_chunk, nth);
-        if (src1->type != vec_dot_type) {
-            char * wdata = params->wdata;
-            const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+    if (src1->type != vec_dot_type) {
+        char * wdata = params->wdata;
+        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
-            assert(params->wsize >= ne11*ne12*ne13*row_size);
-            GGML_ASSERT(src1->type == GGML_TYPE_F32);
+        assert(params->wsize >= ne11*ne12*ne13*row_size);
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
-            for (int64_t i13 = 0; i13 < ne13; ++i13) {
-                for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                    for (int64_t i11 = 0; i11 < ne11; ++i11) {
+        int chore = 0;
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    if (chore == ith) {
                         from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                        wdata += row_size;
                     }
+                    if (++chore == nth) {
+                        chore = 0;
+                    }
+                    wdata += row_size;
                 }
            }
        }
-        return;
+        ggml_syncthreads(params->barrier, params->nth);
    }
 
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
+    const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+    const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
 #if GGML_USE_LLAMAFILE
     if (src1->type != vec_dot_type) {
-        const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
-
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
@@ -12594,87 +12513,98 @@ UseGgmlGemm1:;
 UseGgmlGemm2:;
 #endif
 
-#ifdef GGML_PERF
-    int chunks_executed = 0;
-    UNUSED(chunks_executed);
-#endif
+    const int64_t nr0 = ne01;          // src0 rows
+    const int64_t nr1 = ne1*ne12*ne13; // src1 rows
 
-    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
-    const int64_t nr0 = ne0;
+    //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
 
-    // This is the size of the rest of the dimensions of the result
-    const int64_t nr1 = ne1 * ne2 * ne3;
+    // distribute the thread work across the inner or outer loop based on which one is larger
 
-    // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
-    int64_t num_rows_per_vec_dot = vec_dot_num_rows;
-    // TODO: currently the mmla kernels support only even numbered rows/cols.
-    // this check can be removed once they are extended to support odd numbered rows/cols too
-    if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
-        num_rows_per_vec_dot = 1;
-    }
+    const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+    const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+
+    const int64_t ith0 = ith % nth0;
+    const int64_t ith1 = ith / nth0;
 
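
Since nth0*nth1 == nth, the two lines above pick a one-dimensional split of the thread grid along whichever loop is longer. A worked example under assumed shapes (not from the patch): with nth = 4 threads, nr0 = 4096 src0 rows and nr1 = 32 src1 rows, nr0 > nr1 holds, so the grid is 4x1 and each thread owns a quarter of the src0 rows:

    // assumed: nth = 4, ith = 2, nr0 = 4096, nr1 = 32
    const int64_t nth0 = nr0 > nr1 ? nth : 1;  // = 4: split the src0 rows
    const int64_t nth1 = nr0 > nr1 ? 1 : nth;  // = 1: src1 rows stay whole
    const int64_t ith0 = ith % nth0;           // = 2
    const int64_t ith1 = ith / nth0;           // = 0
    // dr0 = ceil(4096/4) = 1024, so thread 2 gets rows [2048, 3072)
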
-    // Now select a reasonable chunk size.
-    int chunk_size = 16;
+    const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
+    const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
 
-    // We need to step up the size if it's small
-    if (nr0 == 1 || nr1 == 1) {
-        chunk_size = 64;
+    const int64_t ir010 = dr0*ith0;
+    const int64_t ir011 = MIN(ir010 + dr0, nr0);
+
+    const int64_t ir110 = dr1*ith1;
+    const int64_t ir111 = MIN(ir110 + dr1, nr1);
+
+    //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
+
+    // threads with no work simply yield (not sure if it helps)
+    if (ir010 >= ir011 || ir110 >= ir111) {
+        sched_yield();
+        return;
     }
 
-    // distribute the work across the inner or outer loop based on which one is larger
-    // The number of chunks in the 0/1 dim.
-    // CEIL(nr0/chunk_size)
-    int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
-    int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
+    assert(ne12 % ne02 == 0);
+    assert(ne13 % ne03 == 0);
 
-    // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
-    // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
-    // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
-    if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
-        // distribute the thread work across the inner or outer loop based on which one is larger
-        nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
-        nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+    // block-tiling attempt
+    const int64_t blck_0 = 16;
+    const int64_t blck_1 = 16;
+
+    // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
+    int64_t nrc = vec_dot_num_rows;
+    // TODO: currently the mmla kernels support only even numbered rows/cols.
+    // this check can be removed once they are extended to support odd numbered rows/cols too
+    if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
+        nrc = 1;
    }
 
-    // The number of elements in each chunk
-    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
-    const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
+    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
+
+    // attempt to reduce false-sharing (does not seem to make a difference)
+    // 16 * 2, accounting for mmla kernels
+    float tmp[32];
 
-    //if (ith == 0)
-    //    printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
+    for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
+        for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
+                const int64_t i13 = (ir1/(ne12*ne1));
+                const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
+                const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
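
The three divisions above invert the flattening ir1 = i11 + ne1*(i12 + ne12*i13). For example, with assumed extents ne1 = 32 and ne12 = 8 (illustrative values, not from the patch), flat row 300 decomposes as:

    // assumed extents, for illustration only: ne1 = 32, ne12 = 8
    const int64_t ir1 = 300;
    const int64_t i13 = ir1/(8*32);              // 300/256        = 1
    const int64_t i12 = (ir1 - i13*8*32)/32;     // (300 - 256)/32 = 1
    const int64_t i11 = ir1 - i13*8*32 - i12*32; // 300 - 256 - 32 = 12
    // check: 12 + 32*(1 + 8*1) == 300
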
 
-    // The first chunk comes from our thread_id, the rest will get auto-assigned.
-    int current_chunk = ith;
+                // broadcast src0 into src1
+                const int64_t i03 = i13/r3;
+                const int64_t i02 = i12/r2;
 
-    while (current_chunk < nchunk0 * nchunk1) {
-        const int64_t ith0 = current_chunk % nchunk0;
-        const int64_t ith1 = current_chunk / nchunk0;
+                const int64_t i1 = i11;
+                const int64_t i2 = i12;
+                const int64_t i3 = i13;
 
-        const int64_t ir0_start = dr0 * ith0;
-        const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
+                const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
 
-        const int64_t ir1_start = dr1 * ith1;
-        const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
+                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                //       the original src1 data pointer, so we should index using the indices directly
+                // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                const char * src1_col = (const char *) wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                     ? (i11      + i12*ne11 + i13*ne12*ne11)*row_size
+                     : (i11*nb11 + i12*nb12 + i13*nb13));
+                float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
 
-        ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
+                //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
+                //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+                //}
 
-#ifdef GGML_PERF
-        chunks_executed++;
-#endif
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
+                }
 
-        if (nth >= nchunk0 * nchunk1) {
-            break;
+                for (int cn = 0; cn < nrc; ++cn) {
+                    memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
+                }
+            }
         }
-
-        current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
     }
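
Each vec_dot call above fills the stack buffer tmp instead of writing straight to dst, so partial results never touch a cache line another thread may be filling. When nrc == 2, one mmla call produces results for two src1 columns, the second strip landing at tmp + 16, and the write-back copies one strip per column; nb1/nb0 converts the dst row stride from bytes to floats. A sketch of that step with the constants spelled out (values illustrative):

    // write-back of one block: nrc = 2 columns, up to blck_0 = 16 outputs each
    for (int cn = 0; cn < 2; ++cn) {
        // tmp + cn*16          : outputs for the cn-th src1 column of this call
        // iir0 + cn*nb1/nb0    : index of the matching dst row, in floats
        memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + cn*16,
               (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
    }
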
 
-#ifdef GGML_PERF
-    // These numbers are useful when trying to measure how well the threading scheduling works.
-    //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1;
-    //float time = (ggml_perf_time_us() - t0);
-    //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed);
-#endif
 }
 
 // ggml_compute_forward_mul_mat_id
@@ -17833,7 +17763,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 
 /////////////////////////////////
 
-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) {
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     GGML_ASSERT(params);
 
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
@@ -17931,7 +17861,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor, state);
+                ggml_compute_forward_mul_mat(params, tensor);
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
@@ -19523,9 +19453,6 @@ void ggml_graph_clear(struct ggml_cgraph * cgraph) {
 //
 // thread data
 //
-// synchronization is done via busy loops
-// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops
-//
 
 #ifdef __APPLE__
 
@@ -19549,6 +19476,8 @@ typedef int ggml_lock_t;
 
 #define GGML_LOCK_INITIALIZER 0
 
+typedef pthread_t ggml_thread_t;
+
 #define ggml_thread_create pthread_create
 #define ggml_thread_join   pthread_join
 
@@ -19574,6 +19503,8 @@ typedef int ggml_lock_t;
 
 #define GGML_LOCK_INITIALIZER 0
 
+typedef pthread_t ggml_thread_t;
+
 #define ggml_thread_create pthread_create
 #define ggml_thread_join   pthread_join
 
@@ -19653,6 +19584,32 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); }
 static void clear_numa_thread_affinity(void) {}
 #endif
 
+struct ggml_compute_state_shared {
+    const struct ggml_cgraph * cgraph;
+    const struct ggml_cplan * cplan;
+
+    int64_t perf_node_start_cycles;
+    int64_t perf_node_start_time_us;
+
+    const int n_threads;
+
+    // synchronization primitives
+    atomic_int n_active;  // num active threads
+    atomic_int node_n;    // active graph node
+    atomic_int node_task; // active graph node task phase
+    struct ggml_barrier barrier;
+
+    ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
+    void * abort_callback_data;
+};
+
+struct ggml_compute_state {
+    ggml_thread_t thrd;
+    int ith;
+    struct ggml_compute_state_shared * shared;
+    enum ggml_status ec;
+};
+
 static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
     int64_t cycles_cur  = ggml_perf_cycles()  - st->perf_node_start_cycles;
     int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us;
@@ -19914,39 +19871,27 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
     return n_tasks;
 }
 
-static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_compute_state * state, const bool do_yield) {
+static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_compute_state * state) {
     // wait for other threads to finish
     const int last_node_n = * node_n;
+    int backoff = 0;
 
     while (true) {
-        if (do_yield) {
-            sched_yield();
-        }
-
-        * node_n = atomic_load(&state->shared->node_n);
+        * node_n = atomic_load_explicit(&state->shared->node_n, memory_order_acquire);
         if (* node_n != last_node_n) break;
-#if defined(__SSE3__)
-        // Tell the processor we're spinning. It's a processor hint for spinlocks.
-        _mm_pause();
-#endif
+        backoff = ggml_delay(backoff);
     }
 }
 
-static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_compute_state * state, const bool do_yield) {
+static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_compute_state * state) {
     // wait for other threads to finish
     const int last_task_phase = * task_phase;
+    int backoff = 0;
 
     while (true) {
-        if (do_yield) {
-            sched_yield();
-        }
-
-        * task_phase = atomic_load(&state->shared->node_task);
+        * task_phase = atomic_load_explicit(&state->shared->node_task, memory_order_acquire);
         if (* task_phase != last_task_phase) break;
-#if defined(__SSE3__)
-        // Tell the processor we're spinning. It's a processor hint for spinlocks.
-        _mm_pause();
-#endif
+        backoff = ggml_delay(backoff);
     }
 }
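
Both wait loops now share one backoff discipline: ggml_delay busy-spins for 2^0, 2^1, ... up to 2^11 iterations on successive calls, then falls back to sched_yield(). The same pattern can wait on any acquire-ordered flag; a generic sketch (wait_for_change is illustrative, not part of the patch):

    #include <stdatomic.h>

    // spin until *flag differs from old_value, backing off exponentially;
    // returns the newly observed value
    static int wait_for_change(atomic_int * flag, int old_value) {
        int backoff = 0;
        int v;
        while ((v = atomic_load_explicit(flag, memory_order_acquire)) == old_value) {
            backoff = ggml_delay(backoff);
        }
        return v;
    }
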
@@ -19974,11 +19919,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             // all other threads are finished and spinning
             // do finalize and init here so we don't have synchronize again
             struct ggml_compute_params params = {
-                /*.type  =*/ GGML_TASK_TYPE_FINALIZE,
-                /*.ith   =*/ 0,
-                /*.nth   =*/ 0,
-                /*.wsize =*/ cplan->work_size,
-                /*.wdata =*/ cplan->work_data,
+                /*.type    =*/ GGML_TASK_TYPE_FINALIZE,
+                /*.ith     =*/ 0,
+                /*.nth     =*/ 0,
+                /*.wsize   =*/ cplan->work_size,
+                /*.wdata   =*/ cplan->work_data,
+                /*.barrier =*/ &state->shared->barrier,
             };
 
             if (node_n != -1) {
                 struct ggml_tensor * node = cgraph->nodes[node_n];
                 if (GGML_OP_HAS_FINALIZE[node->op]) {
                     params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
-                    ggml_compute_forward(&params, node, state);
+                    ggml_compute_forward(&params, node);
                 }
                 ggml_graph_compute_perf_stats_node(node, state->shared);
             }
@@ -20006,17 +19952,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 /* INIT */
                 if (GGML_OP_HAS_INIT[node->op]) {
                     params.type = GGML_TASK_TYPE_INIT;
-                    ggml_compute_forward(&params, node, state);
+                    ggml_compute_forward(&params, node);
                 }
 
                 // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
                 // they do something more efficient than spinning (?)
                 params.type = GGML_TASK_TYPE_COMPUTE;
-                ggml_compute_forward(&params, node, state);
+                ggml_compute_forward(&params, node);
 
                 if (GGML_OP_HAS_FINALIZE[node->op]) {
                     params.type = GGML_TASK_TYPE_FINALIZE;
-                    ggml_compute_forward(&params, node, state);
+                    ggml_compute_forward(&params, node);
                 }
 
                 ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -20030,12 +19976,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             }
 
             task_phase = GGML_TASK_TYPE_INIT;
-            atomic_store(&state->shared->n_active,  n_threads);
-            atomic_store(&state->shared->node_n,    node_n);
-            atomic_store(&state->shared->node_task, task_phase);
+            atomic_store_explicit(&state->shared->n_active,  n_threads, memory_order_release);
+            atomic_store_explicit(&state->shared->node_n,    node_n, memory_order_release);
+            atomic_store_explicit(&state->shared->node_task, task_phase, memory_order_release);
         } else {
-            ggml_graph_compute_thread_sync_node(&node_n, state, false);
-            ggml_graph_compute_thread_sync_task(&task_phase, state, false);
+            ggml_graph_compute_thread_sync_node(&node_n, state);
+            ggml_graph_compute_thread_sync_task(&task_phase, state);
         }
 
         // check if we should stop
@@ -20046,46 +19992,43 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
 
         struct ggml_compute_params params = {
-            /*.type  =*/ GGML_TASK_TYPE_INIT,
-            /*.ith   =*/ state->ith,
-            /*.nth   =*/ n_tasks,
-            /*.wsize =*/ cplan->work_size,
-            /*.wdata =*/ cplan->work_data,
+            /*.type    =*/ GGML_TASK_TYPE_INIT,
+            /*.ith     =*/ state->ith,
+            /*.nth     =*/ n_tasks,
+            /*.wsize   =*/ cplan->work_size,
+            /*.wdata   =*/ cplan->work_data,
+            /*.barrier =*/ &state->shared->barrier,
         };
 
         if (state->ith < n_tasks) {
             if (GGML_OP_HAS_INIT[node->op]) {
-                ggml_compute_forward(&params, node, state);
+                ggml_compute_forward(&params, node);
             }
         }
 
-        if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
-            task_phase = GGML_TASK_TYPE_COMPUTE;
-            atomic_store(&state->shared->n_active, n_threads);
-            atomic_store(&state->shared->node_task, task_phase);
-        }
-        else {
-            // TODO: this sched_yield can have significant impact on the performance - either positive or negative
-            // depending on the workload and the operating system.
-            // since it is not clear what is the best approach, it should potentially become user-configurable
-            //       ref: https://github.com/ggerganov/ggml/issues/291
-            // UPD:  adding the do_yield flag seems to resolve the issue universally
-            const bool do_yield = node_n < 0 || cgraph->nodes[node_n]->op == GGML_OP_MUL_MAT;
-            ggml_graph_compute_thread_sync_task(&task_phase, state, do_yield);
+        if (GGML_OP_HAS_INIT[node->op]) {
+            if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
+                task_phase = GGML_TASK_TYPE_COMPUTE;
+                atomic_store_explicit(&state->shared->n_active, n_threads, memory_order_release);
+                atomic_store_explicit(&state->shared->node_task, task_phase, memory_order_release);
+            }
+            else {
+                ggml_graph_compute_thread_sync_task(&task_phase, state);
+            }
         }
 
         if (state->ith < n_tasks) {
             params.type = GGML_TASK_TYPE_COMPUTE;
-            ggml_compute_forward(&params, node, state);
+            ggml_compute_forward(&params, node);
         }
 
         if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
            task_phase = GGML_TASK_TYPE_FINALIZE;
-            atomic_store(&state->shared->n_active,  n_threads);
-            atomic_store(&state->shared->node_task, task_phase);
+            atomic_store_explicit(&state->shared->n_active,  n_threads, memory_order_release);
+            atomic_store_explicit(&state->shared->node_task, task_phase, memory_order_release);
        }
        else {
-            ggml_graph_compute_thread_sync_task(&task_phase, state, false);
+            ggml_graph_compute_thread_sync_task(&task_phase, state);
        }
     }
 
@@ -20325,9 +20268,9 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
         /*.n_active            =*/ n_threads,
         /*.node_n              =*/ -1,
         /*.node_task           =*/ GGML_TASK_TYPE_FINALIZE,
+        /*.barrier             =*/ {0, 0},
         /*.abort_callback      =*/ NULL,
         /*.abort_callback_data =*/ NULL,
-        /*.current_chunk;      =*/ 0,
     };
 
     struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
diff --git a/ggml.h b/ggml.h
index 35ac9110ceb17..c9e6dc738cf4c 100644
--- a/ggml.h
+++ b/ggml.h
@@ -680,6 +680,12 @@ extern "C" {
         GGML_TASK_TYPE_FINALIZE,
     };
 
+    struct ggml_once;
+    struct ggml_barrier;
+    int ggml_delay(int backoff);
+    void ggml_syncthreads(struct ggml_barrier * b, int nth);
+    void ggml_once(struct ggml_once * once, void init(void));
+
     struct ggml_compute_params {
         enum ggml_task_type type;
 
@@ -689,6 +695,8 @@ extern "C" {
         // work buffer for all threads
         size_t wsize;
         void * wdata;
+
+        struct ggml_barrier * barrier;
     };
 
     // numa strategies
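
For reference, a freestanding sketch of the exported barrier in action; the scaffolding is illustrative, and because ggml.h only forward-declares struct ggml_barrier, the sketch assumes its definition is in scope (for example, compiled inside ggml.c). No "after" line can print before every "before" line has printed:

    #include <pthread.h>
    #include <stdint.h>
    #include <stdio.h>

    #define NTH 4

    static struct ggml_barrier g_bar;            // zero-initialized: phase 0, count 0

    static void * worker(void * arg) {
        int ith = (int)(intptr_t) arg;
        printf("thread %d: before barrier\n", ith);
        ggml_syncthreads(&g_bar, NTH);           // all NTH threads rendezvous here
        printf("thread %d: after barrier\n", ith);
        return NULL;
    }

    int main(void) {
        pthread_t t[NTH];
        for (int i = 0; i < NTH; i++) {
            pthread_create(&t[i], NULL, worker, (void *)(intptr_t) i);
        }
        for (int i = 0; i < NTH; i++) {
            pthread_join(t[i], NULL);
        }
        return 0;
    }
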