
[Bug]: vLLM sleep experiences segmentation fault when used in TRL #16993


Description

@toslali-ibm

Your current environment

I am using vLLM sleep mode in Hugging Face TRL to efficiently manage GPU memory between training and generation (see my draft PR). Training completes successfully, but I see a segmentation fault at the end of the run. I do not see this error with vLLM==0.7.3, but I do observe it with 0.8.0, 0.8.1, and later.
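
For context, the colocation pattern looks roughly like the sketch below. This is a minimal illustration, not TRL's actual implementation: `run_training_step`, the prompts, and the step count are hypothetical placeholders, while the vLLM calls (`enable_sleep_mode`, `llm.generate`, `llm.sleep`, `llm.wake_up`) are the real API.

from vllm import LLM, SamplingParams

def run_training_step(outputs):
    # Hypothetical stand-in for the trainer's forward/backward/optimizer step.
    pass

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="facebook/opt-125m",
    enable_sleep_mode=True,  # required for llm.sleep() / llm.wake_up()
)

for step in range(3):
    # Generation phase: vLLM owns the GPU memory.
    outputs = llm.generate(prompts, sampling_params)

    # level=1 offloads the weights to CPU and frees the KV cache so the
    # trainer can use the GPU; level=2 also discards the weights, which
    # then have to be re-loaded after wake_up().
    llm.sleep(level=1)

    run_training_step(outputs)  # training phase uses the freed GPU memory

    # Restore the weights and reallocate the KV cache before generating again.
    llm.wake_up()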

The logs of my run (`err.log`):
2025-04-22 15:34:26 - INFO - __main__ - *** Save model ***
[INFO|trainer.py:3984] 2025-04-22 15:34:28,362 >> Saving model checkpoint to trainer_output
[INFO|configuration_utils.py:419] 2025-04-22 15:34:28,366 >> Configuration saved in trainer_output/config.json
[INFO|configuration_utils.py:911] 2025-04-22 15:34:28,367 >> Configuration saved in trainer_output/generation_config.json
[rank4]:[W422 15:34:30.032759954 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank6]:[W422 15:34:30.076030919 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W422 15:34:30.136720629 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[INFO|modeling_utils.py:3572] 2025-04-22 15:34:31,111 >> Model weights saved in trainer_output/model.safetensors
[INFO|tokenization_utils_base.py:2510] 2025-04-22 15:34:31,113 >> tokenizer config file saved in trainer_output/tokenizer_config.json
[INFO|tokenization_utils_base.py:2519] 2025-04-22 15:34:31,114 >> Special tokens file saved in trainer_output/special_tokens_map.json
2025-04-22 15:34:31 - INFO - __main__ - Model saved to trainer_output
[INFO|configuration_utils.py:419] 2025-04-22 15:34:31,369 >> Configuration saved in trainer_output/config.json
[grpo-experiment-mert-experiments-master-0:19399:0:19399] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
==== backtrace (tid:  19399) ====
 0 0x0000000000042520 __sigaction()  ???:0
=================================
terminate called after throwing an instance of 'c10::Error'
  what():  Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:151 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8f2196c1b6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7f8f21915b3f in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x1a7 (0x7f8ed248e667 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x22b78 (0x7f8f21d45b78 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x2320e (0x7f8f21d4620e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #5: <unknown function> + 0x39afa (0x7f8f21d5cafa in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1b9 (0x7f8f21d48329 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #7: <unknown function> + 0xdf74f0 (0x7f8f19a864f0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x516907 (0x7f8f191a5907 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x5174d1 (0x7f8f191a64d1 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #10: <unknown function> + 0x172fd1 (0x55f46de26fd1 in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x55f46dde9c52 in /usr/bin/python)
frame #12: <unknown function> + 0x136991 (0x55f46ddea991 in /usr/bin/python)
frame #13: <unknown function> + 0x13678c (0x55f46ddea78c in /usr/bin/python)
frame #14: <unknown function> + 0x136877 (0x55f46ddea877 in /usr/bin/python)
frame #15: <unknown function> + 0x172aa0 (0x55f46de26aa0 in /usr/bin/python)
frame #16: <unknown function> + 0x1caa87 (0x55f46de7ea87 in /usr/bin/python)
frame #17: <unknown function> + 0x129ebc (0x55f46ddddebc in /usr/bin/python)
frame #18: <unknown function> + 0x264970 (0x55f46df18970 in /usr/bin/python)
frame #19: Py_FinalizeEx + 0x148 (0x55f46df144c8 in /usr/bin/python)
frame #20: Py_RunMain + 0x173 (0x55f46df05913 in /usr/bin/python)
frame #21: Py_BytesMain + 0x2d (0x55f46dedc02d in /usr/bin/python)
frame #22: <unknown function> + 0x29d90 (0x7f8f5a34bd90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #23: __libc_start_main + 0x80 (0x7f8f5a34be40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #24: _start + 0x25 (0x55f46dedbf25 in /usr/bin/python)

[grpo-experiment-mert-experiments-master-0:19394:0:19394] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
terminate called after throwing an instance of 'c10::Error'
[grpo-experiment-mert-experiments-master-0:19398:0:19398] Caught signal 11 (Segmentation fault: address not mapped to object at address 0x400)
  what():  Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:151 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fee9536c1b6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7fee95315b3f in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x1a7 (0x7fee45e8e667 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x22b78 (0x7fee95745b78 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x2320e (0x7fee9574620e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #5: <unknown function> + 0x39afa (0x7fee9575cafa in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1b9 (0x7fee95748329 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #7: <unknown function> + 0xdf74f0 (0x7fee8d4864f0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x516907 (0x7fee8cba5907 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x5174d1 (0x7fee8cba64d1 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #10: <unknown function> + 0x172fd1 (0x564c79c2afd1 in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x564c79bedc52 in /usr/bin/python)
frame #12: <unknown function> + 0x136991 (0x564c79bee991 in /usr/bin/python)
frame #13: <unknown function> + 0x13678c (0x564c79bee78c in /usr/bin/python)
frame #14: <unknown function> + 0x136877 (0x564c79bee877 in /usr/bin/python)
frame #15: <unknown function> + 0x172aa0 (0x564c79c2aaa0 in /usr/bin/python)
frame #16: <unknown function> + 0x1caa87 (0x564c79c82a87 in /usr/bin/python)
frame #17: <unknown function> + 0x129ebc (0x564c79be1ebc in /usr/bin/python)
frame #18: <unknown function> + 0x264970 (0x564c79d1c970 in /usr/bin/python)
frame #19: Py_FinalizeEx + 0x148 (0x564c79d184c8 in /usr/bin/python)
frame #20: Py_RunMain + 0x173 (0x564c79d09913 in /usr/bin/python)
frame #21: Py_BytesMain + 0x2d (0x564c79ce002d in /usr/bin/python)
frame #22: <unknown function> + 0x29d90 (0x7feecdd44d90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #23: __libc_start_main + 0x80 (0x7feecdd44e40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #24: _start + 0x25 (0x564c79cdff25 in /usr/bin/python)

[grpo-experiment-mert-experiments-master-0:19395:0:19395] Caught signal 11 (Segmentation fault: Sent by the kernel at address (nil))
==== backtrace (tid:  19394) ====
 0 0x0000000000042520 __sigaction()  ???:0
=================================
[grpo-experiment-mert-experiments-master-0:19397:0:19397] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
==== backtrace (tid:  19395) ====
 0 0x0000000000042520 __sigaction()  ???:0
 1 0x0000000000022b76 c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::release_block()  CUDACachingAllocator.cpp:0
 2 0x000000000002320e c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::release_blocks()  CUDACachingAllocator.cpp:0
 3 0x0000000000039afa c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::release_cached_blocks()  :0
 4 0x0000000000025329 c10::cuda::MemPool::~MemPool()  ???:0
 5 0x0000000000df74f0 pybind11::class_<c10::cuda::MemPool, std::shared_ptr<c10::cuda::MemPool> >::dealloc()  :0
 6 0x0000000000516907 pybind11::detail::clear_instance()  :0
 7 0x00000000005174d1 pybind11_object_dealloc()  :0
 8 0x0000000000172fd1 PyObject_DelItem()  ???:0
 9 0x0000000000135c52 _Py_CheckFunctionResult()  ???:0
10 0x0000000000136991 _Py_CheckFunctionResult()  ???:0
11 0x000000000013678c _Py_CheckFunctionResult()  ???:0
12 0x0000000000136877 _Py_CheckFunctionResult()  ???:0
13 0x0000000000172aa0 PyObject_DelItem()  ???:0
14 0x00000000001caa87 PyDict_Clear()  ???:0
15 0x0000000000129ebc PyObject_GC_Del()  ???:0
16 0x0000000000264970 PyMarshal_ReadLongFromFile()  ???:0
17 0x00000000002604c8 Py_FinalizeEx()  ???:0
18 0x0000000000251913 Py_RunMain()  ???:0
19 0x000000000022802d Py_BytesMain()  ???:0
20 0x0000000000029d90 __libc_init_first()  ???:0
21 0x0000000000029e40 __libc_start_main()  ???:0
22 0x0000000000227f25 _start()  ???:0
=================================
==== backtrace (tid:  19397) ====
 0 0x0000000000042520 __sigaction()  ???:0
=================================
[rank0]:[W422 15:34:34.190553574 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[grpo-experiment-mert-experiments-master-0:19392:0:19392] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
==== backtrace (tid:  19392) ====
 0 0x0000000000042520 __sigaction()  ???:0
=================================

🐛 Describe the bug

The segmentation fault occurs after distributed training completes.

I tried to reproduce it with a simpler script (see below; an explicit-cleanup sketch follows it). I don't get the segmentation fault there, but I do observe a warning that may be a hint:

[rank0]:[W422 15:49:17.205132298 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
/usr/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
Simple vLLM inference script (`inf.py`):
from vllm import LLM, SamplingParams

# Create prompts, the same across all ranks
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,
    distributed_executor_backend="external_launcher",
    enable_sleep_mode=True,
    seed=1
)

outputs = llm.generate(prompts, sampling_params)

# all ranks will have the same outputs
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

llm.sleep(level=2)
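
As a possible workaround (a sketch only, not a confirmed fix), explicitly tearing down the engine and the process group before the interpreter exits avoids relying on garbage collection during Py_FinalizeEx, which is where the MemPool destructor runs in the backtraces above. Appended to the end of `inf.py`:

import gc

import torch
import torch.distributed as dist

# Drop the engine reference and collect now, while the CUDA context and the
# pluggable allocator's memory pool are still in a well-defined state,
# instead of leaving the teardown to interpreter shutdown.
del llm
gc.collect()
torch.cuda.empty_cache()

# Addresses the destroy_process_group() warning seen above.
if dist.is_initialized():
    dist.destroy_process_group()

This should at least address the destroy_process_group() warning from the simple script; whether it also prevents the segfault in the full TRL run is untested.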

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.

CC @youkaichao, @fabianlim, @fingertap
