Skip to content

Commit d8b5beb

Browse files
authored
Fix2 select_device() for Multi-GPU (#6461)
* Fix2 select_device() for Multi-GPU * Cleanup * Cleanup * Simplify error message * Improve assert * Update torch_utils.py
1 parent 856d4e5 commit d8b5beb

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

utils/datasets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,12 @@
2929
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
3030
from utils.general import (LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
3131
segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
32-
from utils.torch_utils import device_count, torch_distributed_zero_first
32+
from utils.torch_utils import torch_distributed_zero_first
3333

3434
# Parameters
3535
HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
3636
IMG_FORMATS = ['bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp'] # include image suffixes
3737
VID_FORMATS = ['asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'wmv'] # include video suffixes
38-
DEVICE_COUNT = max(device_count(), 1) # number of CUDA devices
3938

4039
# Get orientation exif tag
4140
for orientation in ExifTags.TAGS.keys():
@@ -110,7 +109,8 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
110109
prefix=prefix)
111110

112111
batch_size = min(batch_size, len(dataset))
113-
nw = min([os.cpu_count() // DEVICE_COUNT, batch_size if batch_size > 1 else 0, workers]) # number of workers
112+
nd = torch.cuda.device_count() # number of CUDA devices
113+
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
114114
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
115115
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
116116
return loader(dataset,

utils/torch_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def git_describe(path=Path(__file__).parent): # path must be a directory
5454

5555

5656
def device_count():
57-
# Returns number of CUDA devices available. Safe version of torch.cuda.device_count().
57+
# Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Only works on Linux.
58+
assert platform.system() == 'Linux', 'device_count() function only works on Linux'
5859
try:
5960
cmd = 'nvidia-smi -L | wc -l'
6061
return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
@@ -70,10 +71,9 @@ def select_device(device='', batch_size=0, newline=True):
7071
if cpu:
7172
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
7273
elif device: # non-cpu device requested
73-
nd = device_count() # number of CUDA devices
74-
assert nd > int(max(device.split(','))), f'Invalid `--device {device}` request, valid devices are 0 - {nd - 1}'
7574
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
76-
assert torch.cuda.is_available(), 'CUDA is not available, use `--device cpu` or do not pass a --device'
75+
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
76+
f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
7777

7878
cuda = not cpu and torch.cuda.is_available()
7979
if cuda:

0 commit comments

Comments
 (0)