diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index fd4dcfa941d4b..0abda85331dee 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -703,7 +703,7 @@ struct ggml_backend_cuda_context {
     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-    std::unique_ptr<ggml_cuda_graph> cuda_graph;
+    std::vector<std::unique_ptr<ggml_cuda_graph>> cuda_graphs;
 
     explicit ggml_backend_cuda_context(int device) :
         device(device),
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 6d5d9aa54703b..cd4e8ccfe1198 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2413,13 +2413,19 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
     GGML_UNUSED(backend);
 }
 
+// groups cgraph->nodes offsets per cuda_graph
+struct cgraph_offset {
+    int begin;
+    int end;
+};
+
 #ifdef USE_CUDA_GRAPH
-static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
-    std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
+static bool check_node_graph_compatibility_and_refresh_copy_ops(std::unique_ptr<ggml_cuda_graph> & cuda_graph, ggml_cgraph * cgraph,
+    std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph, cgraph_offset & offset) {
 
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
-    for (int i = 0; i < cgraph->n_nodes; i++) {
+    cuda_graph->updated_kernel_arg.clear();
+    for (int i = offset.begin; i < offset.end; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
         if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
@@ -2451,7 +2457,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 
         if (node->op == GGML_OP_CPY) {
             // store the copy op parameter which changes with each token.
-            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+            cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
             // store a pointer to each copy op CUDA kernel to identify it later
             void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
             if (!ptr) {
@@ -2525,26 +2531,28 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
     return true;
 }
 
-static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
+static void maintain_cuda_graph(std::unique_ptr<ggml_cuda_graph> & cuda_graph, std::vector<void *> & ggml_cuda_cpy_fn_ptrs,
+    bool cuda_graph_update_required) {
 
     if (cuda_graph_update_required) {
         // Extract nodes from graph
         // First call with null argument gets number of nodes in graph
-        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+        CUDA_CHECK(cudaGraphGetNodes(cuda_graph->graph, nullptr, &cuda_graph->num_nodes));
         // Subsequent call with non-null argument gets nodes
-        cuda_ctx->cuda_graph->nodes.clear();
-        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-        cuda_ctx->cuda_graph->params.clear();
-        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-        if (cuda_ctx->cuda_graph->num_nodes > 0) {
-            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
+        cuda_graph->nodes.clear();
+        cuda_graph->nodes.resize(cuda_graph->num_nodes);
+        cuda_graph->params.clear();
+        cuda_graph->params.resize(cuda_graph->num_nodes);
+        if (cuda_graph->num_nodes > 0) {
+            CUDA_CHECK(cudaGraphGetNodes(cuda_graph->graph, cuda_graph->nodes.data(), &cuda_graph->num_nodes));
 
             // Loop over nodes, and extract kernel parameters from each node
-            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+            for (size_t i = 0; i < cuda_graph->num_nodes; i++) {
                 cudaGraphNodeType node_type;
-                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+                CUDA_CHECK(cudaGraphNodeGetType(cuda_graph->nodes[i], &node_type));
                 if (node_type == cudaGraphNodeTypeKernel) {
-                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+                    // Get params using runtime
+                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_graph->nodes[i], &cuda_graph->params[i]);
                     if (stat == cudaErrorInvalidDeviceFunction) {
                         // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
                         // We don't need to update blas nodes, so clear error and move on.
@@ -2560,54 +2568,55 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto
     // replace that argument with the updated value in the CUDA graph
     // on update steps, the live parameters will already be captured
     int k = 0;
-    for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-        if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-            char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-            cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
-            CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
+    for (size_t i = 0; i < cuda_graph->num_nodes; i++) {
+        if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_graph->params[i].func) > 0) {
+            char ** updated_kernel_arg_ptr = cuda_graph->updated_kernel_arg.at(k++);
+            cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+            CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_graph->nodes[i], &cuda_graph->params[i]));
         }
     }
 }
 
-static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
+static bool is_cuda_graph_update_required(std::unique_ptr<ggml_cuda_graph> & cuda_graph, ggml_cgraph * cgraph,
+    cgraph_offset & offset) {
 
     bool cuda_graph_update_required = false;
 
-    if (cuda_ctx->cuda_graph->instance == nullptr) {
+    if (cuda_graph->instance == nullptr) {
         cuda_graph_update_required = true;
     }
 
     // Check if the graph size has changed
-    if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+    if (cuda_graph->ggml_graph_properties.size() != (size_t)(offset.end - offset.begin)) {
         cuda_graph_update_required = true;
-        cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+        cuda_graph->ggml_graph_properties.resize((offset.end - offset.begin));
     }
 
     // Loop over nodes in GGML graph to determine if CUDA graph update is required
     // and store properties to allow this comparison for the next token
-    for (int i = 0; i < cgraph->n_nodes; i++) {
+    for (int i = offset.begin; i < offset.end; i++) {
         bool has_matching_properties = true;
         if (!cuda_graph_update_required) {
-            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i - offset.begin]);
         }
         if (!has_matching_properties) {
             cuda_graph_update_required = true;
         }
-        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i - offset.begin]);
     }
 
     return cuda_graph_update_required;
 }
 
-static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
+static void update_cuda_graph_executable(std::unique_ptr<ggml_cuda_graph> & cuda_graph) {
 
     cudaGraphExecUpdateResultInfo result_info;
 #ifdef __HIP_PLATFORM_AMD__
     hipGraphNode_t errorNode;
-    hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+    hipError_t stat = hipGraphExecUpdate(cuda_graph->instance, cuda_graph->graph, &errorNode, &result_info);
 #else
-    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+    cudaError_t stat = cudaGraphExecUpdate(cuda_graph->instance, cuda_graph->graph, &result_info);
 #endif
     if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
@@ -2617,24 +2626,24 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
         // The pre-existing graph exec cannot be updated due to violated constraints
         // so instead clear error and re-instantiate
         (void)cudaGetLastError();
-        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
-        cuda_ctx->cuda_graph->instance = nullptr;
-        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        CUDA_CHECK(cudaGraphExecDestroy(cuda_graph->instance));
+        cuda_graph->instance = nullptr;
+        CUDA_CHECK(cudaGraphInstantiate(&cuda_graph->instance, cuda_graph->graph, NULL, NULL, 0));
     } else {
         GGML_ASSERT(stat == cudaSuccess);
     }
 }
 #endif
 
-static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
-    [[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph,
-    bool & cuda_graph_update_required) {
+static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, [[maybe_unused]] std::unique_ptr<ggml_cuda_graph> & cuda_graph,
+    ggml_cgraph * cgraph, [[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs,
+    bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required, cgraph_offset & offset) {
 
     while (!graph_evaluated_or_captured) {
         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
         // With the use of CUDA graphs, the execution will be performed by the graph launch.
         if (!use_cuda_graph || cuda_graph_update_required) {
-            for (int i = 0; i < cgraph->n_nodes; i++) {
+            for (int i = offset.begin; i < offset.end; i++) {
                 ggml_tensor * node = cgraph->nodes[i];
 
                 if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
@@ -2662,12 +2671,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
 #ifdef USE_CUDA_GRAPH
         if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
-            if (cuda_ctx->cuda_graph->graph != nullptr) {
-                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
-                cuda_ctx->cuda_graph->graph = nullptr;
+            if (cuda_graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(cuda_graph->graph));
+                cuda_graph->graph = nullptr;
             }
 
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
         } else {
             graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
@@ -2675,18 +2684,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         }
 
         if (use_cuda_graph) {
-            if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
-                CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+            if (cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+                CUDA_CHECK(cudaGraphInstantiate(&cuda_graph->instance, cuda_graph->graph, NULL, NULL, 0));
             }
 
             // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-            maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
+            maintain_cuda_graph(cuda_graph, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
 
             // Update graph executable
-            update_cuda_graph_executable(cuda_ctx);
+            update_cuda_graph_executable(cuda_graph);
 
             // Launch graph
-            CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+            CUDA_CHECK(cudaGraphLaunch(cuda_graph->instance, cuda_ctx->stream()));
 #else
             graph_evaluated_or_captured = true;
 #endif // USE_CUDA_GRAPH
@@ -2702,70 +2711,100 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
     // kernel parameters which need updated in the graph for each token
     std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
+    // Heuristic to minimize GPU idle time. Work is split over several CUDA graphs,
+    // to overlap graph building (CPU) and graph execution (GPU).
+    // The first graphs are small to minimize the time in which the CPU prepares work and the GPU is idle.
+    // After that, graph building (CPU) is done in parallel to the execution of another previously built graph (GPU).
+    int first_graph_subset = 20;
+    int second_graph_subset = 50;
+    int remaining_graph_subset = 100;
+    int remaining_nodes = (cgraph->n_nodes - first_graph_subset) - second_graph_subset;
+    int num_cuda_graphs_required = 2 + (remaining_nodes / remaining_graph_subset);
+    cuda_ctx->cuda_graphs.resize(num_cuda_graphs_required);
+    cgraph_offset offset {0,0};
+
+    for (size_t i = 0; i < cuda_ctx->cuda_graphs.size(); i++) {
+        auto & cuda_graph = cuda_ctx->cuda_graphs[i];
+
+        offset.begin = offset.end;
+        if (i == 0) offset.end += first_graph_subset;
+        if (i == 1) offset.end += second_graph_subset;
+        if (i >= 2) offset.end += remaining_graph_subset;
+
+        // last graph does the rest
+        if ((i + 1) == cuda_ctx->cuda_graphs.size()) offset.end = cgraph->n_nodes;
+
+        // special case for graphs smaller than the ramp-up heuristic
+        if (cgraph->n_nodes <= first_graph_subset + second_graph_subset) {
+            offset.end = cgraph->n_nodes;
+            if (i > 0) break;
+        }
+
 #ifdef USE_CUDA_GRAPH
-    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+        static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
-    // Objects required for CUDA Graph
-    if (cuda_ctx->cuda_graph == nullptr) {
-        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
-    }
+        // Objects required for CUDA Graph
+        if (cuda_graph == nullptr) {
+            cuda_graph = std::make_unique<ggml_cuda_graph>();
+        }
 
-    bool use_cuda_graph = true;
-    bool cuda_graph_update_required = false;
+        bool use_cuda_graph = true;
+        bool cuda_graph_update_required = false;
 
-    if (cuda_ctx->cuda_graph->graph == nullptr) {
-        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
-            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+        if (cuda_graph->graph == nullptr) {
+            if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
+                cuda_graph->disable_due_to_gpu_arch = true;
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
 #endif
+            }
         }
-    }
 
-    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
-    // or previous graph capture failure.
-    // Also disable for multi-gpu for now. TO DO investigate
-    if (disable_cuda_graphs_due_to_env
-        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
-        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
-        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
-        use_cuda_graph = false;
-    }
+        // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
+        // or previous graph capture failure.
+        // Also disable for multi-gpu for now. TO DO investigate
+        if (disable_cuda_graphs_due_to_env
+            || cuda_graph->disable_due_to_gpu_arch
+            || cuda_graph->disable_due_to_too_many_updates
+            || cuda_graph->disable_due_to_failed_graph_capture) {
+            use_cuda_graph = false;
+        }
 
-    if (use_cuda_graph) {
-        cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
+        if (use_cuda_graph) {
+            cuda_graph_update_required = is_cuda_graph_update_required(cuda_graph, cgraph, offset);
 
-        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
-            ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
+            use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_graph, cgraph,
+                ggml_cuda_cpy_fn_ptrs, use_cuda_graph, offset);
 
-        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-        if (use_cuda_graph && cuda_graph_update_required) {
-            cuda_ctx->cuda_graph->number_consecutive_updates++;
-        } else {
-            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
-        }
+            // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+            if (use_cuda_graph && cuda_graph_update_required) {
+                cuda_graph->number_consecutive_updates++;
+            } else {
+                cuda_graph->number_consecutive_updates = 0;
+            }
 
-        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
-            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+            if (cuda_graph->number_consecutive_updates >= 4) {
+                cuda_graph->disable_due_to_too_many_updates = true;
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
 #endif
+            }
         }
-    }
 
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
-        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
-    }
+        if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+            CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+        }
 
 #else
-    bool use_cuda_graph = false;
-    bool cuda_graph_update_required = false;
+        bool use_cuda_graph = false;
+        bool cuda_graph_update_required = false;
 #endif // USE_CUDA_GRAPH
 
-    bool graph_evaluated_or_captured = false;
-
-    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
+        bool graph_evaluated_or_captured = false;
+        evaluate_and_capture_cuda_graph(cuda_ctx, cuda_graph, cgraph, ggml_cuda_cpy_fn_ptrs,
+            graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required, offset);
+    }
 
     return GGML_STATUS_SUCCESS;
 }
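For reference, the splitting heuristic introduced in `ggml_backend_cuda_graph_compute` can be read in isolation: the first two CUDA graphs cover 20 and 50 nodes so the GPU can start working early, every subsequent graph covers 100 nodes, the last graph absorbs the remainder, and graphs with at most 70 nodes are not split at all. Below is a minimal standalone C++ sketch of just that partitioning logic; the helper name `split_cgraph` and the `node_range` struct are illustrative stand-ins (not part of the patch), which uses `cgraph_offset` and computes the ranges inline.

```cpp
#include <cstdio>
#include <vector>

// Illustrative stand-in for the patch's cgraph_offset: a half-open range
// [begin, end) of ggml graph nodes that one CUDA graph is responsible for.
struct node_range { int begin; int end; };

// Sketch of the ramp-up heuristic: small graphs first (20, then 50 nodes),
// then 100-node chunks that the CPU builds while the GPU runs the previous one.
static std::vector<node_range> split_cgraph(int n_nodes) {
    const int first  = 20;
    const int second = 50;
    const int rest   = 100;

    std::vector<node_range> ranges;
    if (n_nodes <= first + second) {
        // graph too small for the ramp-up: keep it as a single CUDA graph
        ranges.push_back({0, n_nodes});
        return ranges;
    }

    int remaining_nodes = (n_nodes - first) - second;
    int num_graphs      = 2 + remaining_nodes / rest;

    int end = 0;
    for (int i = 0; i < num_graphs; i++) {
        int begin = end;
        end += (i == 0) ? first : (i == 1) ? second : rest;
        if (i + 1 == num_graphs) {
            end = n_nodes; // last graph does the rest
        }
        ranges.push_back({begin, end});
    }
    return ranges;
}

int main() {
    // e.g. a 300-node ggml graph is split into chunks of 20, 50, 100 and 130 nodes
    for (const node_range & r : split_cgraph(300)) {
        std::printf("[%d, %d)\n", r.begin, r.end);
    }
}
```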