Skip to content

Commit eec86cc

Browse files
committed
UCT/MD: Add a new UCT memory query API
Unlike the existing uct_md_mem_query(), then new API is not specific to a single MD. Therefore, it does not need a handle to an opened MD.
1 parent ae3b47b commit eec86cc

File tree

4 files changed

+294
-69
lines changed

4 files changed

+294
-69
lines changed

src/uct/base/uct_md.c

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ ucs_config_field_t uct_md_config_rcache_table[] = {
5252
{NULL}
5353
};
5454

55+
uct_mem_query_func_t mqf[UCS_MEMORY_TYPE_LAST];
56+
5557

5658
ucs_status_t uct_md_open(uct_component_h component, const char *md_name,
5759
const uct_md_config_t *config, uct_md_h *md_p)
@@ -494,3 +496,54 @@ void uct_md_set_rcache_params(ucs_rcache_params_t *rcache_params,
494496
rcache_params->max_regions = rcache_config->max_regions;
495497
rcache_params->max_size = rcache_config->max_size;
496498
}
499+
500+
static int uct_mem_attr_cmp_host(uct_mem_attr_h mem_attr1,
501+
uct_mem_attr_h mem_attr2)
502+
{
503+
/* host memory attributes always compare equal */
504+
return 0;
505+
}
506+
507+
static void uct_mem_attr_destroy_host(uct_mem_attr_h mem_attr)
508+
{
509+
/* Nothing to be done for host memory attributes */
510+
}
511+
512+
/* all host memory will have the same attributes.
513+
* so, they will all point to this static struct */
514+
static uct_mem_attr_t mem_attr_host = {
515+
.mem_type = UCS_MEMORY_TYPE_HOST,
516+
.cmp = uct_mem_attr_cmp_host,
517+
.destroy = uct_mem_attr_destroy_host
518+
};
519+
520+
ucs_status_t uct_mem_attr_query(const void *address, size_t length,
521+
uct_mem_attr_h *mem_attr_p)
522+
{
523+
ucs_status_t status;
524+
ucs_memory_type_t mt;
525+
for (mt = UCS_MEMORY_TYPE_HOST + 1; mt < UCS_MEMORY_TYPE_LAST; mt++) {
526+
if (mqf[mt] != NULL) {
527+
status = mqf[mt](address, length, mem_attr_p);
528+
if (status == UCS_OK) {
529+
return UCS_OK;
530+
}
531+
}
532+
}
533+
534+
/* none of the MDs recognized the address. So, it must be HOST */
535+
*mem_attr_p = &mem_attr_host;
536+
return UCS_OK;
537+
}
538+
539+
ucs_status_t uct_mem_attr_query_type(const void *address, size_t length,
540+
ucs_memory_type_t *mem_type)
541+
{
542+
ucs_status_t status;
543+
uct_mem_attr_h mem_attr;
544+
status = uct_mem_attr_query(address, length, &mem_attr);
545+
if (status != UCS_OK) return UCS_ERR_NO_RESOURCE;
546+
*mem_type = mem_attr->mem_type;
547+
uct_mem_attr_destroy(mem_attr);
548+
return UCS_OK;
549+
}

src/uct/base/uct_md.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,4 +214,54 @@ static inline ucs_log_level_t uct_md_reg_log_lvl(unsigned flags)
214214
UCS_LOG_LEVEL_ERROR;
215215
}
216216

217+
typedef struct uct_mem_attr *uct_mem_attr_h;
218+
typedef struct uct_mem_attr {
219+
ucs_memory_type_t mem_type;
220+
ucs_sys_device_t sys_dev;
221+
int (*cmp)(uct_mem_attr_h mem_attr1, uct_mem_attr_h mem_attr2);
222+
void (*destroy)(uct_mem_attr_h mem_attr);
223+
} uct_mem_attr_t;
224+
225+
typedef ucs_status_t (*uct_mem_query_func_t)(const void *addr, size_t length,
226+
uct_mem_attr_h *mem_attr_p);
227+
228+
ucs_status_t uct_mem_attr_query(const void *address, size_t length,
229+
uct_mem_attr_h *mem_attr_p);
230+
231+
/* Getting the type directly from address and length */
232+
ucs_status_t uct_mem_attr_query_type(const void *address, size_t length,
233+
ucs_memory_type_t *mem_type);
234+
235+
static inline ucs_memory_type_t
236+
uct_mem_attr_get_type(uct_mem_attr_h mem_attr)
237+
{
238+
return mem_attr->mem_type;
239+
}
240+
241+
static inline ucs_sys_device_t
242+
uct_mem_attr_get_sys_dev(uct_mem_attr_h mem_attr)
243+
{
244+
return mem_attr->sys_dev;
245+
}
246+
247+
static inline int
248+
uct_mem_attr_cmp(uct_mem_attr_h mem_attr1, uct_mem_attr_h mem_attr2)
249+
{
250+
if (mem_attr1->mem_type == mem_attr2->mem_type) {
251+
return mem_attr1->cmp(mem_attr1, mem_attr2);
252+
}
253+
return 1;
254+
}
255+
256+
static inline void uct_mem_attr_destroy(uct_mem_attr_h mem_attr)
257+
{
258+
mem_attr->destroy(mem_attr);
259+
}
260+
261+
#define UCT_MEM_QUERY_REGISTER(_mem_query_func, _mem_type) \
262+
extern uct_mem_query_func_t mqf[]; \
263+
UCS_STATIC_INIT { \
264+
mqf[_mem_type] = _mem_query_func; \
265+
}
266+
217267
#endif

src/uct/cuda/base/cuda_md.c

Lines changed: 132 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,108 @@
1818
#include <cuda.h>
1919

2020

21+
typedef struct uct_cuda_mem_attr {
22+
uct_mem_attr_t super;
23+
unsigned long long buf_id;
24+
void *base_address;
25+
size_t alloc_length;
26+
} uct_cuda_mem_attr_t;
27+
28+
static int uct_cuda_mem_attr_cmp(const uct_mem_attr_h mem_attr1,
29+
const uct_mem_attr_h mem_attr2)
30+
{
31+
uct_cuda_mem_attr_t *cuda_mem_attr1, *cuda_mem_attr2;
32+
cuda_mem_attr1 = ucs_derived_of(mem_attr1, uct_cuda_mem_attr_t);
33+
cuda_mem_attr2 = ucs_derived_of(mem_attr2, uct_cuda_mem_attr_t);
34+
return cuda_mem_attr1->buf_id == cuda_mem_attr2->buf_id ? 0 : 1;
35+
}
36+
37+
static void uct_cuda_mem_attr_destroy(uct_mem_attr_h mem_attr)
38+
{
39+
uct_cuda_mem_attr_t *cuda_mem_attr;
40+
cuda_mem_attr = ucs_derived_of(mem_attr, uct_cuda_mem_attr_t);
41+
ucs_free(cuda_mem_attr);
42+
}
43+
44+
UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_mem_attr_query,
45+
(address, length, mem_attr_p),
46+
const void *address, size_t length,
47+
uct_mem_attr_h *mem_attr_p)
48+
{
49+
#define UCT_CUDA_MEM_QUERY_NUM_ATTRS 4
50+
CUmemorytype cuda_mem_mype = (CUmemorytype)0;
51+
unsigned long long buf_id = 0;
52+
uint32_t is_managed = 0;
53+
CUdevice cuda_device = -1;
54+
void *base_address = (void*)address;
55+
size_t alloc_length = length;
56+
ucs_sys_device_t sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
57+
CUpointer_attribute attr_type[UCT_CUDA_MEM_QUERY_NUM_ATTRS];
58+
void *attr_data[UCT_CUDA_MEM_QUERY_NUM_ATTRS];
59+
ucs_memory_type_t mem_type;
60+
uct_cuda_mem_attr_t *cuda_mem_attr;
61+
CUresult cu_err;
62+
const char *cu_err_str;
63+
ucs_status_t status;
64+
65+
if (address == NULL) {
66+
return UCS_ERR_INVALID_ADDR;
67+
}
68+
69+
attr_type[0] = CU_POINTER_ATTRIBUTE_MEMORY_TYPE;
70+
attr_data[0] = &cuda_mem_mype;
71+
attr_type[1] = CU_POINTER_ATTRIBUTE_IS_MANAGED;
72+
attr_data[1] = &is_managed;
73+
attr_type[2] = CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
74+
attr_data[2] = &cuda_device;
75+
attr_type[3] = CU_POINTER_ATTRIBUTE_BUFFER_ID;
76+
attr_data[3] = &buf_id;
77+
78+
cu_err = cuPointerGetAttributes(ucs_static_array_size(attr_data),
79+
attr_type, attr_data,
80+
(CUdeviceptr)address);
81+
if ((cu_err != CUDA_SUCCESS) || (cuda_mem_mype != CU_MEMORYTYPE_DEVICE)) {
82+
/* pointer not recognized */
83+
return UCS_ERR_INVALID_ADDR;
84+
}
85+
86+
if (is_managed) {
87+
mem_type = UCS_MEMORY_TYPE_CUDA_MANAGED;
88+
} else {
89+
mem_type = UCS_MEMORY_TYPE_CUDA;
90+
}
91+
92+
status = uct_cuda_base_get_sys_dev(cuda_device, &sys_dev);
93+
if (status != UCS_OK) {
94+
return status;
95+
}
96+
97+
cu_err = cuMemGetAddressRange((CUdeviceptr*)&base_address,
98+
&alloc_length, (CUdeviceptr)address);
99+
if (cu_err != CUDA_SUCCESS) {
100+
cuGetErrorString(cu_err, &cu_err_str);
101+
ucs_error("ccuMemGetAddressRange(%p) error: %s", address,
102+
cu_err_str);
103+
return UCS_ERR_INVALID_ADDR;
104+
}
105+
106+
cuda_mem_attr = ucs_malloc(sizeof(*cuda_mem_attr), "cuda_mem_attr");
107+
if (cuda_mem_attr == NULL) {
108+
return UCS_ERR_NO_MEMORY;
109+
}
110+
111+
cuda_mem_attr->buf_id = buf_id;
112+
cuda_mem_attr->base_address = base_address;
113+
cuda_mem_attr->alloc_length = alloc_length;
114+
cuda_mem_attr->super.mem_type = mem_type;
115+
cuda_mem_attr->super.sys_dev = sys_dev;
116+
cuda_mem_attr->super.cmp = uct_cuda_mem_attr_cmp;
117+
cuda_mem_attr->super.destroy = uct_cuda_mem_attr_destroy;
118+
119+
*mem_attr_p = &cuda_mem_attr->super;
120+
return UCS_OK;
121+
}
122+
21123
ucs_status_t uct_cuda_base_get_sys_dev(CUdevice cuda_device,
22124
ucs_sys_device_t *sys_dev_p)
23125
{
@@ -61,7 +163,8 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_base_detect_memory_type,
61163
uct_md_h md, const void *address, size_t length,
62164
ucs_memory_type_t *mem_type_p)
63165
{
64-
uct_md_mem_attr_t mem_attr;
166+
/* self-initializing to suppress wrong maybe-uninitialized error */
167+
uct_md_mem_attr_t mem_attr = mem_attr;
65168
ucs_status_t status;
66169

67170
mem_attr.field_mask = UCT_MD_MEM_ATTR_FIELD_MEM_TYPE;
@@ -80,20 +183,9 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_base_mem_query,
80183
uct_md_h md, const void *address, size_t length,
81184
uct_md_mem_attr_t *mem_attr)
82185
{
83-
#define UCT_CUDA_MEM_QUERY_NUM_ATTRS 3
84-
CUmemorytype cuda_mem_mype = (CUmemorytype)0;
85-
uint32_t is_managed = 0;
86-
unsigned value = 1;
87-
CUdevice cuda_device = -1;
88-
void *base_address = (void*)address;
89-
size_t alloc_length = length;
90-
ucs_sys_device_t sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
91-
CUpointer_attribute attr_type[UCT_CUDA_MEM_QUERY_NUM_ATTRS];
92-
void *attr_data[UCT_CUDA_MEM_QUERY_NUM_ATTRS];
93-
ucs_memory_type_t mem_type;
94-
const char *cu_err_str;
186+
uct_mem_attr_h mem_attr_h;
187+
uct_cuda_mem_attr_t *cuda_mem_attr;
95188
ucs_status_t status;
96-
CUresult cu_err;
97189

98190
if (!(mem_attr->field_mask & (UCT_MD_MEM_ATTR_FIELD_MEM_TYPE |
99191
UCT_MD_MEM_ATTR_FIELD_SYS_DEV |
@@ -102,76 +194,45 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_base_mem_query,
102194
return UCS_OK;
103195
}
104196

105-
if (address == NULL) {
106-
mem_type = UCS_MEMORY_TYPE_HOST;
107-
} else {
108-
attr_type[0] = CU_POINTER_ATTRIBUTE_MEMORY_TYPE;
109-
attr_data[0] = &cuda_mem_mype;
110-
attr_type[1] = CU_POINTER_ATTRIBUTE_IS_MANAGED;
111-
attr_data[1] = &is_managed;
112-
attr_type[2] = CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
113-
attr_data[2] = &cuda_device;
114-
115-
cu_err = cuPointerGetAttributes(ucs_static_array_size(attr_data),
116-
attr_type, attr_data,
117-
(CUdeviceptr)address);
118-
if ((cu_err != CUDA_SUCCESS) || (cuda_mem_mype != CU_MEMORYTYPE_DEVICE)) {
119-
/* pointer not recognized */
120-
return UCS_ERR_INVALID_ADDR;
121-
}
197+
status = uct_cuda_mem_attr_query(address, length, &mem_attr_h);
198+
if (status != UCS_OK) {
199+
return status;
200+
}
122201

123-
if (is_managed) {
124-
mem_type = UCS_MEMORY_TYPE_CUDA_MANAGED;
125-
} else {
126-
mem_type = UCS_MEMORY_TYPE_CUDA;
127-
128-
/* Synchronize for DMA */
129-
cu_err = cuPointerSetAttribute(&value,
130-
CU_POINTER_ATTRIBUTE_SYNC_MEMOPS,
131-
(CUdeviceptr)address);
132-
if (cu_err != CUDA_SUCCESS) {
133-
cuGetErrorString(cu_err, &cu_err_str);
134-
ucs_warn("cuPointerSetAttribute(%p) error: %s", address,
135-
cu_err_str);
136-
}
202+
if (uct_mem_attr_get_type(mem_attr_h) == UCS_MEMORY_TYPE_CUDA) {
203+
unsigned value = 1;
204+
CUresult cu_err;
205+
const char *cu_err_str;
206+
/* Synchronize for DMA */
207+
cu_err = cuPointerSetAttribute(&value,
208+
CU_POINTER_ATTRIBUTE_SYNC_MEMOPS,
209+
(CUdeviceptr)address);
210+
if (cu_err != CUDA_SUCCESS) {
211+
cuGetErrorString(cu_err, &cu_err_str);
212+
ucs_warn("cuPointerSetAttribute(%p) error: %s", address,
213+
cu_err_str);
137214
}
215+
}
138216

139-
if (mem_attr->field_mask & UCT_MD_MEM_ATTR_FIELD_SYS_DEV) {
140-
status = uct_cuda_base_get_sys_dev(cuda_device, &sys_dev);
141-
if (status != UCS_OK) {
142-
return status;
143-
}
144-
}
217+
cuda_mem_attr = ucs_derived_of(mem_attr_h, uct_cuda_mem_attr_t);
145218

146-
if (mem_attr->field_mask & (UCT_MD_MEM_ATTR_FIELD_ALLOC_LENGTH |
147-
UCT_MD_MEM_ATTR_FIELD_BASE_ADDRESS)) {
148-
cu_err = cuMemGetAddressRange((CUdeviceptr*)&base_address,
149-
&alloc_length, (CUdeviceptr)address);
150-
if (cu_err != CUDA_SUCCESS) {
151-
cuGetErrorString(cu_err, &cu_err_str);
152-
ucs_error("ccuMemGetAddressRange(%p) error: %s", address,
153-
cu_err_str);
154-
return UCS_ERR_INVALID_ADDR;
155-
}
156-
}
219+
if (mem_attr->field_mask & UCT_MD_MEM_ATTR_FIELD_SYS_DEV) {
220+
mem_attr->sys_dev = cuda_mem_attr->super.sys_dev;
157221
}
158222

159223
if (mem_attr->field_mask & UCT_MD_MEM_ATTR_FIELD_MEM_TYPE) {
160-
mem_attr->mem_type = mem_type;
161-
}
162-
163-
if (mem_attr->field_mask & UCT_MD_MEM_ATTR_FIELD_SYS_DEV) {
164-
mem_attr->sys_dev = sys_dev;
224+
mem_attr->mem_type = cuda_mem_attr->super.mem_type;
165225
}
166226

167227
if (mem_attr->field_mask & UCT_MD_MEM_ATTR_FIELD_BASE_ADDRESS) {
168-
mem_attr->base_address = base_address;
228+
mem_attr->base_address = cuda_mem_attr->base_address;
169229
}
170230

171231
if (mem_attr->field_mask & UCT_MD_MEM_ATTR_FIELD_ALLOC_LENGTH) {
172-
mem_attr->alloc_length = alloc_length;
232+
mem_attr->alloc_length = cuda_mem_attr->alloc_length;
173233
}
174234

235+
uct_mem_attr_destroy(mem_attr_h);
175236
return UCS_OK;
176237
}
177238

@@ -198,3 +259,5 @@ UCS_MODULE_INIT() {
198259
UCS_MODULE_FRAMEWORK_LOAD(uct_cuda, 0);
199260
return UCS_OK;
200261
}
262+
263+
UCT_MEM_QUERY_REGISTER(uct_cuda_mem_attr_query, UCS_MEMORY_TYPE_CUDA);

0 commit comments

Comments
 (0)