
Commit 9241930

Revert "Support cpu tensor transfer with NIXL in GPU Objects" (#56026)

Reverts #55793, fixing the lint check.

1 parent: a5d032b

8 files changed, +29 -79 lines

python/ray/experimental/channel/serialization_context.py

Lines changed: 1 addition & 3 deletions
@@ -97,9 +97,7 @@ def serialize_tensor(
         from ray.experimental.channel import ChannelContext

         ctx = ChannelContext.get_current()
-        if self._use_external_transport and (
-            ctx._torch_device is None or ctx._torch_device == tensor.device
-        ):
+        if self._use_external_transport and tensor.device == ctx.torch_device:
             # External transport is enabled and we found a tensor that matches
             # our device. Add the actual tensor to a buffer. The buffer of
             # tensors should later be popped by the caller and sent via
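The net effect of this hunk: a tensor is handed to the external transport only when its device exactly matches the channel context's device; the reverted "ctx._torch_device is None" escape hatch had let any device (including CPU) qualify. A minimal standalone sketch of the restored check, with a plain ctx_device argument standing in for ctx.torch_device (the function name is hypothetical, for illustration only):

import torch

def should_use_external_transport(
    use_external_transport: bool,
    ctx_device: "torch.device",
    tensor: "torch.Tensor",
) -> bool:
    # Restored behavior: require an exact device match. A CPU tensor no
    # longer qualifies when the context device is a GPU.
    return use_external_transport and tensor.device == ctx_device

# A CPU tensor against a CUDA context device is serialized inline again:
assert not should_use_external_transport(
    True, torch.device("cuda"), torch.tensor([1, 2, 3])
)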

python/ray/experimental/collective/collective_tensor_transport.py

Lines changed: 2 additions & 14 deletions
@@ -40,19 +40,8 @@ def __ray_get_tensor_transport_metadata__(
         # it could take arbitrarily long and we don't want to trigger a spurious
         # timeout.
         gpu_object = gpu_object_store.wait_and_get_object(obj_id)
-        tensor_meta = []
-        device = None
-        if gpu_object:
-            device = gpu_object[0].device
-        for t in gpu_object:
-            if t.device.type != device.type:
-                raise ValueError(
-                    "All tensors in one GPU object must be the same device type."
-                )
-            tensor_meta.append((t.shape, t.dtype))
         return CollectiveTransportMetadata(
-            tensor_meta=tensor_meta,
-            tensor_device=device,
+            tensor_meta=[(t.shape, t.dtype) for t in gpu_object],
         )

     # Submit a Ray actor task to the source actor to get the tensor metadata.

@@ -141,11 +130,10 @@ def recv_multiple_tensors(
     def send_multiple_tensors(
         tensors: List["torch.Tensor"],
         communicator_metadata: CollectiveCommunicatorMetadata,
+        device: "torch.device",
     ):
         import ray.util.collective as collective

-        device = tensor_transport_metadata.tensor_device
-
         for tensor in tensors:
             if tensor.device.type != device.type:
                 # TODO(swang): Right now there is no way to catch this error
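Both this file and nixl_tensor_transport.py below replace the device-tracking loop with a plain list comprehension; when gpu_object is empty the comprehension simply yields an empty metadata list, so the old "device = None" bookkeeping is unnecessary. A small runnable sketch of the restored construction (build_tensor_meta is a hypothetical name, not from the diff):

import torch

def build_tensor_meta(gpu_object):
    # Restored construction: record only (shape, dtype) per tensor; no
    # device field and no same-device-type validation.
    return [(t.shape, t.dtype) for t in gpu_object]

assert build_tensor_meta([]) == []
assert build_tensor_meta([torch.zeros(2, 3)]) == [(torch.Size([2, 3]), torch.float32)]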

python/ray/experimental/collective/nixl_tensor_transport.py

Lines changed: 1 addition & 12 deletions
@@ -45,25 +45,14 @@ def __ray_get_tensor_transport_metadata__(
         from ray.util.collective.collective import get_group_handle

         nixl_backend: NixlBackend = get_group_handle(NIXL_GROUP_NAME)
-        device = None
-        tensor_meta = []
         if gpu_object:
             serialized_descs, agent_meta = nixl_backend.get_nixl_metadata(
                 gpu_object
             )
-            # We assume all tensors in one GPU object have the same device type.
-            device = gpu_object[0].device
-            for t in gpu_object:
-                if t.device.type != device.type:
-                    raise ValueError(
-                        "All tensors in one GPU object must be the same device type."
-                    )
-                tensor_meta.append((t.shape, t.dtype))
         else:
             serialized_descs, agent_meta = None, None
         return NixlTransportMetadata(
-            tensor_meta=tensor_meta,
-            tensor_device=device,
+            tensor_meta=[(t.shape, t.dtype) for t in gpu_object],
             nixl_serialized_descs=serialized_descs,
             nixl_agent_meta=agent_meta,
         )

python/ray/experimental/collective/tensor_transport_manager.py

Lines changed: 2 additions & 0 deletions
@@ -143,11 +143,13 @@ def recv_multiple_tensors(
     def send_multiple_tensors(
         tensors: List["torch.Tensor"],
         communicator_metadata: CommunicatorMetadata,
+        device: "torch.device",
     ):
         """
         Send multiple tensors to the destination actor.

         Args:
             tensors: The tensors to send.
             communicator_metadata: The communicator metadata for the send/recv operation.
+            device: The device to send the tensors to.
         """

python/ray/experimental/collective/util.py

Lines changed: 1 addition & 16 deletions
@@ -1,4 +1,4 @@
-from typing import Tuple, TYPE_CHECKING
+from typing import Tuple
 from contextlib import closing
 import socket

@@ -11,9 +11,6 @@
     CollectiveTensorTransport,
 )

-if TYPE_CHECKING:
-    import torch
-
 # Singleton instances for tensor transport managers
 _nixl_tensor_transport_manager = None
 _collective_tensor_transport_manager = None

@@ -44,18 +41,6 @@ def get_tensor_transport_manager(
     raise ValueError(f"Unsupported tensor transport protocol: {tensor_transport}")


-def device_match_transport(device: "torch.device", tensor_transport: Backend) -> bool:
-    """Check if the device matches the transport."""
-    if tensor_transport == Backend.NIXL:
-        return device.type == "cuda" or device.type == "cpu"
-    elif tensor_transport == Backend.TORCH_GLOO:
-        return device.type == "cpu"
-    elif tensor_transport == Backend.NCCL:
-        return device.type == "cuda"
-    else:
-        raise ValueError(f"Unsupported tensor transport protocol: {tensor_transport}")
-
-
 def find_free_port() -> int:
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
         s.bind(("", 0))

python/ray/experimental/gpu_object_manager/gpu_object_store.py

Lines changed: 15 additions & 14 deletions
@@ -11,9 +11,6 @@
     TensorTransportMetadata,
 )

-from ray.experimental.collective import get_tensor_transport_manager
-from ray.experimental.collective.util import device_match_transport
-
 try:
     import torch
 except ImportError:

@@ -28,6 +25,14 @@
     TensorTransportEnum.NIXL: Backend.NIXL,
 }

+COLLECTIVE_BACKEND_TO_TORCH_DEVICE = {
+    Backend.NCCL: torch.device("cuda"),
+    Backend.TORCH_GLOO: torch.device("cpu"),
+    # TODO(Qiaolin-Yu): NIXL could also transfer tensors from CPU to CPU.
+    # More details in https://github.com/ray-project/ray/issues/55587.
+    Backend.NIXL: torch.device("cuda"),
+}
+

 def _tensor_transport_to_collective_backend(
     tensor_transport: TensorTransportEnum,

@@ -56,17 +61,15 @@ def __ray_send__(
     tensors = gpu_object_store.get_object(obj_id)

     backend = collective.get_group_handle(communicator_meta.communicator_name).backend()
+    device = COLLECTIVE_BACKEND_TO_TORCH_DEVICE[backend]
+
+    from ray.experimental.collective import get_tensor_transport_manager

     tensor_transport_manager = get_tensor_transport_manager(backend)
-    if tensors and not device_match_transport(
-        tensor_transport_meta.tensor_device, backend
-    ):
-        raise ValueError(
-            f"Tensor transport backend {backend} does not support tensor transfer on device {tensor_transport_meta.tensor_device}."
-        )
     tensor_transport_manager.send_multiple_tensors(
         tensors,
         communicator_meta,
+        device=device,
     )

@@ -79,16 +82,14 @@ def __ray_recv__(
     """Helper function that runs on the dst actor to receive tensors from the src actor."""
     from ray._private.worker import global_worker

+    from ray.experimental.collective import get_tensor_transport_manager
+
     backend = collective.get_group_handle(communicator_meta.communicator_name).backend()

-    device = tensor_transport_meta.tensor_device
+    device = COLLECTIVE_BACKEND_TO_TORCH_DEVICE[backend]
     tensor_meta = tensor_transport_meta.tensor_meta

     gpu_object_store = global_worker.gpu_object_manager.gpu_object_store
-    if tensor_meta and not device_match_transport(device, backend):
-        raise ValueError(
-            f"Tensor transport backend {backend} does not support tensor transfer on device {device}."
-        )
     tensors = []
     for meta in tensor_meta:
         shape, dtype = meta
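With this revert, the target device is a static property of the collective backend rather than per-object metadata, so an unsupported backend/device combination now surfaces as a KeyError on the mapping instead of the removed device_match_transport check. A minimal sketch of the lookup, using string keys as stand-ins for the Backend enum members (the helper name is hypothetical):

import torch

# String keys stand in for ray.util.collective.types.Backend members.
BACKEND_TO_DEVICE = {
    "nccl": torch.device("cuda"),
    "torch_gloo": torch.device("cpu"),
    "nixl": torch.device("cuda"),  # CPU-to-CPU NIXL is the feature this commit reverts
}

def device_for_backend(backend: str) -> "torch.device":
    # Plain dict lookup: an unsupported backend fails fast with KeyError.
    return BACKEND_TO_DEVICE[backend]

assert device_for_backend("torch_gloo") == torch.device("cpu")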

python/ray/tests/test_gpu_objects_nixl.py

Lines changed: 7 additions & 17 deletions
@@ -7,11 +7,10 @@
 @ray.remote(num_gpus=1, num_cpus=0, enable_tensor_transport=True)
 class GPUTestActor:
     @ray.method(tensor_transport="nixl")
-    def echo(self, data, device):
-        return data.to(device)
+    def echo(self, data):
+        return data.to("cuda")

-    def sum(self, data, device):
-        assert data.device.type == device
+    def sum(self, data):
         return data.sum().item()

@@ -24,21 +23,12 @@ def test_p2p(ray_start_regular):

     # Create test tensor
     tensor = torch.tensor([1, 2, 3])
-
-    tensor1 = torch.tensor([4, 5, 6])
-
-    # Test GPU to GPU transfer
-    ref = src_actor.echo.remote(tensor, "cuda")
+    ref = src_actor.echo.remote(tensor)

     # Trigger tensor transfer from src to dst actor
-    result = dst_actor.sum.remote(ref, "cuda")
+    result = dst_actor.sum.remote(ref)
     assert tensor.sum().item() == ray.get(result)

-    # Test CPU to CPU transfer
-    ref1 = src_actor.echo.remote(tensor1, "cpu")
-    result1 = dst_actor.sum.remote(ref1, "cpu")
-    assert tensor1.sum().item() == ray.get(result1)
-

 @pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 1}], indirect=True)
 def test_intra_gpu_tensor_transfer(ray_start_regular):

@@ -47,8 +37,8 @@ def test_intra_gpu_tensor_transfer(ray_start_regular):
     tensor = torch.tensor([1, 2, 3])

     # Intra-actor communication for pure GPU tensors
-    ref = actor.echo.remote(tensor, "cuda")
-    result = actor.sum.remote(ref, "cuda")
+    ref = actor.echo.remote(tensor)
+    result = actor.sum.remote(ref)
     assert tensor.sum().item() == ray.get(result)

python/ray/util/collective/types.py

Lines changed: 0 additions & 3 deletions
@@ -61,12 +61,9 @@ class TensorTransportMetadata:

     Args:
         tensor_meta: A list of tuples, each containing the shape and dtype of a tensor.
-        tensor_device: The device of the tensor. Currently, we require all tensors in the
-            list have the same device type.
     """

     tensor_meta: List[Tuple["torch.Size", "torch.dtype"]]
-    tensor_device: Optional["torch.device"] = None


 @dataclass
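For reference, after this deletion the dataclass carries only per-tensor shape/dtype metadata. A simplified sketch of the resulting shape (the class name and default_factory default are assumptions for the sketch, not taken from the diff):

from dataclasses import dataclass, field
from typing import List, Tuple
import torch

@dataclass
class TensorTransportMetadataSketch:
    # Post-revert: one (shape, dtype) pair per tensor; no tensor_device field.
    tensor_meta: List[Tuple["torch.Size", "torch.dtype"]] = field(default_factory=list)

meta = TensorTransportMetadataSketch([(torch.Size([2]), torch.int64)])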
