Commit 81959c7

[NewFeature]custom_allreduce support cudagraph recapture (#4305)

* custom_allreduce support cudagraph recapture
* add shut_down/restart default group

1 parent 7c91907

7 files changed (+32 −3 lines)
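How the seven changes fit together: the CUDA CustomAllreduce class gains a reusable clear_ipc_handles() that its destructor now delegates to; all_reduce.cu and cpp_extensions.cc expose that method to Python through pybind11; the Python CustomAllreduce wrapper and fastdeploy.distributed.communication forward it; and the cudagraph backend's clear_graph() calls it before dropping captured graphs, so a later recapture starts from clean IPC state. Independently of that chain, the RL dynamic weight manager also shuts down and restarts the default process group, per the second commit bullet.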

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 4 additions & 0 deletions
@@ -623,6 +623,8 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);
 
 void free_shared_buffer(int64_t buffer);
 
+void clear_ipc_handles(int64_t _fa);
+
 // speculative decoding Kernel
 std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
     const paddle::Tensor& input_ids,
@@ -1229,6 +1231,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");
 
+  m.def("clear_ipc_handles", &clear_ipc_handles, "clear_ipc_handles");
+
   m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
 
   m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");

custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu

Lines changed: 5 additions & 1 deletion
@@ -122,10 +122,14 @@ void register_graph_buffers(fptr_t _fa,
   for (int i = 0; i < handles.size(); i++) {
     bytes.emplace_back(handles[i].begin(), handles[i].end());
   }
-  bytes.reserve(handles.size());
   fa->register_graph_buffers(bytes, offsets);
 }
 
+void clear_ipc_handles(fptr_t _fa) {
+  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
+  fa->clear_ipc_handles();
+}
+
 std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
     int64_t size) {
 
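Aside from the new C wrapper around CustomAllreduce::clear_ipc_handles, the hunk also drops bytes.reserve(handles.size()), which sat after the emplace_back loop and therefore had no effect.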
custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh

Lines changed: 6 additions & 1 deletion
@@ -517,10 +517,15 @@ class CustomAllreduce {
 #undef KL
   }
 
-  ~CustomAllreduce() {
+  void clear_ipc_handles(){
     for (auto [_, ptr] : ipc_handles_) {
       CUDACHECK(cudaIpcCloseMemHandle(ptr));
     }
+    ipc_handles_.clear();
+  }
+
+  ~CustomAllreduce() {
+    clear_ipc_handles();
   }
 };
 }  // namespace paddle
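The destructor's cleanup moves into a public clear_ipc_handles() that additionally empties ipc_handles_, so the same teardown can run while the object stays alive and is safe to run again later (a second call iterates an empty map); the destructor now simply delegates to it.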

fastdeploy/distributed/communication.py

Lines changed: 6 additions & 0 deletions
@@ -42,6 +42,12 @@ def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
         _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)
 
 
+def custom_ar_clear_ipc_handles():
+    global _TP_AR
+    if _TP_AR is not None:
+        _TP_AR.clear_ipc_handles()
+
+
 try:
 
     @paddle.jit.marker.unified
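custom_ar_clear_ipc_handles() gives callers a module-level entry point that is a no-op when custom all-reduce was never enabled (_TP_AR is None), so the cudagraph backend can call it unconditionally.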

fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py

Lines changed: 4 additions & 0 deletions
@@ -25,6 +25,7 @@
 from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
 from fastdeploy.model_executor.ops.gpu import (
     all_reduce,
+    clear_ipc_handles,
     dispose,
     get_graph_buffer_ipc_meta,
     init_custom_all_reduce,
@@ -220,6 +221,9 @@ def custom_all_reduce(self, input: paddle.Tensor) -> Optional[paddle.Tensor]:
         else:
             return self.all_reduce(input, input, registered=False)
 
+    def clear_ipc_handles(self):
+        clear_ipc_handles(self._ptr)
+
     def close(self):
         if self._ptr:
             dispose(self._ptr)
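A minimal usage sketch, not code from the repository; the constructor arguments mirror the call in fastdeploy/distributed/communication.py, and the intent is that clear_ipc_handles() runs between captures while close() remains final teardown:

from fastdeploy.distributed.custom_all_reduce.custom_all_reduce import CustomAllreduce

def recapture_with_fresh_ipc_handles(model_parallel_group, max_bytes: int = 8192 * 1024) -> None:
    ar = CustomAllreduce(model_parallel_group, max_bytes)
    # ... capture CUDA graphs whose kernels call ar.custom_all_reduce(...) ...
    ar.clear_ipc_handles()  # close cudaIpc mappings opened for the old graphs
    # ... recapture; new graph buffers get registered and mapped again ...
    ar.close()              # dispose() the underlying C++ object at shutdown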

fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py

Lines changed: 5 additions & 1 deletion
@@ -25,7 +25,10 @@
 
 from fastdeploy import envs
 from fastdeploy.config import FDConfig
-from fastdeploy.distributed.communication import capture_custom_allreduce
+from fastdeploy.distributed.communication import (
+    capture_custom_allreduce,
+    custom_ar_clear_ipc_handles,
+)
 from fastdeploy.utils import get_logger
 
 logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log")
@@ -227,6 +230,7 @@ def _create_entry_dict(self):
     def clear_graph(self):
         """ """
         # Clear graphs
+        custom_ar_clear_ipc_handles()
         for id, entry in self.concrete_size_entries.items():
             if entry.cuda_graph:
                 del entry.cuda_graph
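clear_graph() now closes the custom all-reduce IPC handles before the captured graphs are deleted, so a subsequent capture can register fresh graph buffers instead of tripping over stale cudaIpc mappings; this is the "cudagraph recapture" named in the commit title.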

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,7 @@ def update_parameters(self, pid: int = 0) -> None:
 
         # step1 : restart paddle process group
         if not self.first_load:
+            paddle.distributed.restart_process_group()
             paddle.distributed.restart_process_group(self.parallel_config.tp_group)
             if self.parallel_config.enable_expert_parallel:
                 paddle.distributed.restart_process_group(self.parallel_config.ep_group)
@@ -148,6 +149,7 @@ def clear_parameters(self, pid: int = 0) -> None:
         if self.parallel_config.enable_expert_parallel:
             paddle.distributed.barrier(self.parallel_config.ep_group)
             paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
+        paddle.distributed.shutdown_process_group()
         self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
 
     def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
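The weight manager now treats the default process group like the tp/ep groups: clear_parameters() additionally calls paddle.distributed.shutdown_process_group() with no argument, and update_parameters() restarts the default group before the tp/ep groups, matching the commit bullet "add shut_down/restart default group".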
