@@ -2479,9 +2479,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
     bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters which need updated in the graph for each token
-    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
     if (cuda_ctx->cuda_graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
@@ -2527,7 +2524,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
     }
 
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2554,16 +2550,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 #endif
         }
 
-        if (node->op == GGML_OP_CPY) {
-            // store the copy op parameter which changes with each token.
-            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-            // store a pointer to each copy op CUDA kernel to identify it later
-            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-            if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                ggml_cuda_cpy_fn_ptrs.push_back(ptr);
-            }
-        }
-
         if (!use_cuda_graph) {
             break;
         }
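
The block removed above relies on double indirection: `updated_kernel_arg` stores the *address* of `node->src[1]->data` rather than its value, because the tensor's `data` pointer is rewritten every token while the address of that field stays fixed. A minimal host-side sketch of the idea (the `tensor` struct is hypothetical, not llama.cpp code):

```cpp
// Host-side sketch of why the removed code stored (char **) &(node->src[1]->data)
// instead of the data pointer itself: only the address of the field is stable.
#include <cstdio>

struct tensor { void * data; };  // hypothetical stand-in for ggml_tensor

int main() {
    tensor t = { (void *) 0x1000 };
    char ** arg_location = (char **) &t.data; // what updated_kernel_arg held

    t.data = (void *) 0x2000;                 // next token: field rewritten
    // Dereferencing the stored location always yields the current pointer,
    // which could then be written into a graph node's kernelParams slot.
    printf("current arg: %p\n", (void *) *arg_location);
    return 0;
}
```
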
@@ -2653,64 +2639,23 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
     }
 
-    // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-
     if (cuda_graph_update_required) {
-        // Extract nodes from graph
-        // First call with null argument gets number of nodes in graph
-        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-        // Subsequent call with non-null argument gets nodes
-        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-        if (cuda_ctx->cuda_graph->num_nodes > 0) {
-            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-
-            // Loop over nodes, and extract kernel parameters from each node
-            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                cudaGraphNodeType node_type;
-                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
-                if (node_type == cudaGraphNodeTypeKernel) {
-                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
-                    if (stat == cudaErrorInvalidDeviceFunction) {
-                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
-                        // We don't need to update blas nodes, so clear error and move on.
-                        cudaGetLastError();
-                    } else {
-                        GGML_ASSERT(stat == cudaSuccess);
-                    }
-                }
-            }
-        }
-    }
-
-    // One of the arguments to the copy kernel is updated for each token, hence we need to
-    // replace that argument with the updated value in the CUDA graph
-    if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
-        int k = 0;
-        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-            if (count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-                cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
-                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
-            }
-        }
-    }
-
-    // Update graph executable
-    cudaGraphExecUpdateResultInfo result_info;
-    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-    if (stat == cudaErrorGraphExecUpdateFailure) {
+        // Update graph executable
+        cudaGraphExecUpdateResultInfo result_info;
+        cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+        if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
-        GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
+            GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
 #endif
-        // The pre-existing graph exec cannot be updated due to violated constraints
-        // so instead clear error and re-instantiate
-        cudaGetLastError();
-        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
-        cuda_ctx->cuda_graph->instance = nullptr;
-        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
-    } else {
-        GGML_ASSERT(stat == cudaSuccess);
+            // The pre-existing graph exec cannot be updated due to violated constraints
+            // so instead clear error and re-instantiate
+            cudaGetLastError();
+            CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+            cuda_ctx->cuda_graph->instance = nullptr;
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        } else {
+            GGML_ASSERT(stat == cudaSuccess);
+        }
     }
     // Launch graph
     CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
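
The node-introspection code removed in this hunk used the two-call `cudaGraphGetNodes` idiom: a first call with a null array to obtain the node count, then a second call to fill the sized buffers, after which kernel nodes can be identified by type and their `kernelParams` rewritten via `cudaGraphKernelNodeSetParams`. A self-contained sketch of that pattern around a toy stream capture (the `scale` kernel and `CHECK` macro are illustrative, not llama.cpp code):

```cpp
// Sketch of the two-call cudaGraphGetNodes pattern the removed block used.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

#define CHECK(x) do { cudaError_t e = (x); if (e != cudaSuccess) { \
    printf("CUDA error %s at line %d\n", cudaGetErrorString(e), __LINE__); return 1; } } while (0)

__global__ void scale(float * v, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) v[i] *= 2.0f;
}

int main() {
    float * d; CHECK(cudaMalloc(&d, 32 * sizeof(float)));
    cudaStream_t stream; CHECK(cudaStreamCreate(&stream));

    // Capture one kernel launch into a graph.
    cudaGraph_t graph;
    CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
    scale<<<1, 32, 0, stream>>>(d, 32);
    CHECK(cudaStreamEndCapture(stream, &graph));

    // 1st call with null array: obtain the node count only.
    size_t num_nodes = 0;
    CHECK(cudaGraphGetNodes(graph, nullptr, &num_nodes));

    // 2nd call with a sized array: obtain the nodes themselves.
    std::vector<cudaGraphNode_t> nodes(num_nodes);
    std::vector<cudaKernelNodeParams> params(num_nodes);
    CHECK(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));

    for (size_t i = 0; i < num_nodes; i++) {
        cudaGraphNodeType type;
        CHECK(cudaGraphNodeGetType(nodes[i], &type));
        if (type == cudaGraphNodeTypeKernel) {
            // params[i].func identifies the kernel; params[i].kernelParams points
            // at the argument slots that cudaGraphKernelNodeSetParams could rewrite,
            // which is how the removed code swapped the copy destination per token.
            CHECK(cudaGraphKernelNodeGetParams(nodes[i], &params[i]));
            printf("kernel node %zu: func=%p\n", i, params[i].func);
        }
    }
    return 0;
}
```
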
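What the hunk keeps is the update-or-reinstantiate fallback: `cudaGraphExecUpdate` patches the existing executable in place when the re-captured graph has the same topology, and only an update failure forces the costlier destroy/`cudaGraphInstantiate` path. A standalone sketch of that fallback under the same assumptions (the CUDA 12 `cudaGraphExecUpdateResultInfo` signature the file already uses; the `fill` kernel and `CHECK` macro are illustrative, not llama.cpp code):

```cpp
// Sketch of the update-or-reinstantiate fallback retained by this commit.
#include <cstdio>
#include <cuda_runtime.h>

#define CHECK(x) do { cudaError_t e = (x); if (e != cudaSuccess) { \
    printf("CUDA error %s at line %d\n", cudaGetErrorString(e), __LINE__); return 1; } } while (0)

__global__ void fill(float * v, float x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) v[i] = x;
}

static void capture(cudaStream_t stream, cudaGraph_t * graph, float * v, float x) {
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    fill<<<1, 32, 0, stream>>>(v, x, 32);
    cudaStreamEndCapture(stream, graph);
}

int main() {
    float * v; CHECK(cudaMalloc(&v, 32 * sizeof(float)));
    cudaStream_t stream; CHECK(cudaStreamCreate(&stream));

    cudaGraph_t graph;
    cudaGraphExec_t instance;
    capture(stream, &graph, v, 1.0f);
    CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));

    // Re-capture with a changed kernel argument: same topology, new params.
    CHECK(cudaGraphDestroy(graph));
    capture(stream, &graph, v, 2.0f);

    cudaGraphExecUpdateResultInfo result_info;
    cudaError_t stat = cudaGraphExecUpdate(instance, graph, &result_info);
    if (stat == cudaErrorGraphExecUpdateFailure) {
        // Update rejected: clear the sticky error and rebuild the executable.
        cudaGetLastError();
        CHECK(cudaGraphExecDestroy(instance));
        CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
    } else {
        CHECK(stat);
    }
    CHECK(cudaGraphLaunch(instance, stream));
    CHECK(cudaStreamSynchronize(stream));
    printf("relaunched updated graph\n");
    return 0;
}
```
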