Skip to content

Commit 08b7fa1

Browse files
committed
UCT/MD: Add a new UCT memory query API
Unlike the existing uct_md_mem_query(), the new API is not specific to a single MD. Therefore, it does not need a handle to an opened MD.
1 parent 2f50844 commit 08b7fa1

File tree

4 files changed

+273
-52
lines changed

4 files changed

+273
-52
lines changed

src/uct/base/uct_md.c

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ ucs_config_field_t uct_md_config_rcache_table[] = {
4242
{NULL}
4343
};
4444

45+
static uct_mem_query_func_t mqf[UCS_MEMORY_TYPE_LAST];
46+
4547

4648
ucs_status_t uct_md_open(uct_component_h component, const char *md_name,
4749
const uct_md_config_t *config, uct_md_h *md_p)
@@ -489,3 +491,54 @@ ucs_status_t uct_md_detect_memory_type(uct_md_h md, const void *addr, size_t len
489491
{
490492
return md->ops->detect_memory_type(md, addr, length, mem_type_p);
491493
}
494+
495+
/**
 * Comparison callback installed in the host memory attribute object.
 *
 * @param [in] mem_attr1  First memory attribute handle (not inspected).
 * @param [in] mem_attr2  Second memory attribute handle (not inspected).
 *
 * @return Always 0: any two host memory attributes are reported as equal.
 */
static int uct_mem_attr_cmp_host(uct_mem_attr_h mem_attr1,
                                 uct_mem_attr_h mem_attr2)
{
    const int host_attrs_equal = 0;

    return host_attrs_equal;
}
501+
502+
/**
 * Destroy callback installed in the host memory attribute object.
 *
 * @param [in] mem_attr  Memory attribute handle (not inspected).
 *
 * The function body is empty, so calling it has no effect.
 */
static void uct_mem_attr_destroy_host(uct_mem_attr_h mem_attr)
{
    /* Intentionally empty: there is no per-query state to release */
    return;
}
506+
507+
/* all host memory will have the same attributes.
508+
* so, they will all point to this static struct */
509+
static uct_mem_attr_t mem_attr_host = {
510+
.mem_type = UCS_MEMORY_TYPE_HOST,
511+
.cmp = uct_mem_attr_cmp_host,
512+
.destroy = uct_mem_attr_destroy_host
513+
};
514+
515+
ucs_status_t uct_mem_attr_query(const void *address, size_t length,
516+
uct_mem_attr_h *mem_attr_p)
517+
{
518+
ucs_status_t status;
519+
ucs_memory_type_t mt;
520+
for (mt = UCS_MEMORY_TYPE_HOST + 1; mt < UCS_MEMORY_TYPE_LAST; mt++) {
521+
if (mqf[mt] != NULL) {
522+
status = mqf[mt](address, length, mem_attr_p);
523+
if (status == UCS_OK) {
524+
return UCS_OK;
525+
}
526+
}
527+
}
528+
529+
/* none of the MDs recognized the address. So, it must be HOST */
530+
*mem_attr_p = &mem_attr_host;
531+
return UCS_OK;
532+
}
533+
534+
ucs_status_t uct_mem_attr_query_type(const void *address, size_t length,
535+
ucs_memory_type_t *mem_type)
536+
{
537+
ucs_status_t status;
538+
uct_mem_attr_h mem_attr;
539+
status = uct_mem_attr_query(address, length, &mem_attr);
540+
if (status != UCS_OK) return UCS_ERR_NO_RESOURCE;
541+
*mem_type = mem_attr->mem_type;
542+
uct_mem_attr_destroy(mem_attr);
543+
return UCS_OK;
544+
}

src/uct/base/uct_md.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,4 +206,54 @@ static inline ucs_log_level_t uct_md_reg_log_lvl(unsigned flags)
206206
UCS_LOG_LEVEL_ERROR;
207207
}
208208

209+
typedef struct uct_mem_attr *uct_mem_attr_h;
210+
typedef struct uct_mem_attr {
211+
ucs_memory_type_t mem_type;
212+
ucs_sys_device_t sys_dev;
213+
int (*cmp)(uct_mem_attr_h mem_attr1, uct_mem_attr_h mem_attr2);
214+
void (*destroy)(uct_mem_attr_h mem_attr);
215+
} uct_mem_attr_t;
216+
217+
typedef ucs_status_t (*uct_mem_query_func_t)(const void *addr, size_t length,
218+
uct_mem_attr_h *mem_attr_p);
219+
220+
ucs_status_t uct_mem_attr_query(const void *address, size_t length,
221+
uct_mem_attr_h *mem_attr_p);
222+
223+
/* Getting the type directly from address and length */
224+
ucs_status_t uct_mem_attr_query_type(const void *address, size_t length,
225+
ucs_memory_type_t *mem_type);
226+
227+
static inline ucs_memory_type_t
228+
uct_mem_attr_get_type(uct_mem_attr_h mem_attr)
229+
{
230+
return mem_attr->mem_type;
231+
}
232+
233+
static inline ucs_sys_device_t
234+
uct_mem_attr_get_sys_dev(uct_mem_attr_h mem_attr)
235+
{
236+
return mem_attr->sys_dev;
237+
}
238+
239+
static inline int
240+
uct_mem_attr_cmp(uct_mem_attr_h mem_attr1, uct_mem_attr_h mem_attr2)
241+
{
242+
if (mem_attr1->mem_type == mem_attr2->mem_type) {
243+
return mem_attr1->cmp(mem_attr1, mem_attr2);
244+
}
245+
return 1;
246+
}
247+
248+
/**
 * Release a memory attribute object through its own destroy callback.
 *
 * @param [in] mem_attr  Memory attribute handle; dereferenced without a
 *                       NULL check.
 */
static inline void uct_mem_attr_destroy(uct_mem_attr_h mem_attr)
{
    void (*destroy_cb)(uct_mem_attr_h) = mem_attr->destroy;

    destroy_cb(mem_attr);
}
252+
253+
/**
 * Register a memory query function for a memory type.
 *
 * Expands to a declaration of the query-function table followed by a static
 * initializer that stores @a _mem_query_func in the slot of @a _mem_type.
 *
 * The table is declared as an array of unknown size so that the declaration
 * matches its array definition; declaring it as a pointer would make the
 * initializer treat the table contents as an address.
 *
 * NOTE(review): the table definition in uct_md.c is `static`; it needs
 * external linkage for this declaration to resolve at link time.
 *
 * @param _mem_query_func  Function of type uct_mem_query_func_t.
 * @param _mem_type        Memory type (ucs_memory_type_t) used as the index.
 */
#define UCT_MEM_QUERY_REGISTER(_mem_query_func, _mem_type) \
    extern uct_mem_query_func_t mqf[]; \
    UCS_STATIC_INIT { \
        mqf[(_mem_type)] = (_mem_query_func); \
    }
258+
209259
#endif

src/uct/cuda/base/cuda_md.c

Lines changed: 111 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,91 @@
1818
#include <cuda.h>
1919

2020

21+
typedef struct uct_cuda_mem_attr {
22+
uct_mem_attr_t super;
23+
unsigned long long buf_id;
24+
} uct_cuda_mem_attr_t;
25+
26+
static int uct_cuda_mem_attr_cmp(const uct_mem_attr_h mem_attr1,
27+
const uct_mem_attr_h mem_attr2)
28+
{
29+
uct_cuda_mem_attr_t *cuda_mem_attr1, *cuda_mem_attr2;
30+
cuda_mem_attr1 = ucs_derived_of(mem_attr1, uct_cuda_mem_attr_t);
31+
cuda_mem_attr2 = ucs_derived_of(mem_attr2, uct_cuda_mem_attr_t);
32+
return cuda_mem_attr1->buf_id == cuda_mem_attr2->buf_id ? 0 : 1;
33+
}
34+
35+
/**
 * Free a CUDA memory attribute object.
 *
 * @param [in] mem_attr  Handle that refers to a uct_cuda_mem_attr_t; the
 *                       containing object is released with ucs_free().
 */
static void uct_cuda_mem_attr_destroy(uct_mem_attr_h mem_attr)
{
    ucs_free(ucs_derived_of(mem_attr, uct_cuda_mem_attr_t));
}
41+
42+
UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_mem_attr_query,
43+
(address, length, mem_attr_p),
44+
const void *address, size_t length,
45+
uct_mem_attr_h *mem_attr_p)
46+
{
47+
#define UCT_CUDA_MEM_QUERY_NUM_ATTRS 4
48+
CUmemorytype cuda_mem_mype = (CUmemorytype)0;
49+
unsigned long long buf_id = 0;
50+
uint32_t is_managed = 0;
51+
CUdevice cuda_device = -1;
52+
CUpointer_attribute attr_type[UCT_CUDA_MEM_QUERY_NUM_ATTRS];
53+
void *attr_data[UCT_CUDA_MEM_QUERY_NUM_ATTRS];
54+
ucs_memory_type_t mem_type;
55+
ucs_sys_device_t sys_dev;
56+
ucs_status_t status;
57+
CUresult cu_err;
58+
uct_cuda_mem_attr_t *cuda_mem_attr;
59+
60+
if (address == NULL) {
61+
return UCS_ERR_INVALID_ADDR;
62+
}
63+
64+
attr_type[0] = CU_POINTER_ATTRIBUTE_MEMORY_TYPE;
65+
attr_data[0] = &cuda_mem_mype;
66+
attr_type[1] = CU_POINTER_ATTRIBUTE_IS_MANAGED;
67+
attr_data[1] = &is_managed;
68+
attr_type[2] = CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
69+
attr_data[2] = &cuda_device;
70+
attr_type[3] = CU_POINTER_ATTRIBUTE_BUFFER_ID;
71+
attr_data[3] = &buf_id;
72+
73+
cu_err = cuPointerGetAttributes(ucs_static_array_size(attr_data),
74+
attr_type, attr_data,
75+
(CUdeviceptr)address);
76+
if ((cu_err != CUDA_SUCCESS) || (cuda_mem_mype != CU_MEMORYTYPE_DEVICE)) {
77+
/* pointer not recognized */
78+
return UCS_ERR_INVALID_ADDR;
79+
}
80+
81+
if (is_managed) {
82+
mem_type = UCS_MEMORY_TYPE_CUDA_MANAGED;
83+
} else {
84+
mem_type = UCS_MEMORY_TYPE_CUDA;
85+
}
86+
87+
status = uct_cuda_base_get_sys_dev(cuda_device, &sys_dev);
88+
if (status != UCS_OK) {
89+
return status;
90+
}
91+
92+
cuda_mem_attr = ucs_malloc(sizeof(*cuda_mem_attr), "cuda_mem_attr");
93+
if (cuda_mem_attr == NULL) {
94+
return UCS_ERR_NO_MEMORY;
95+
}
96+
cuda_mem_attr->buf_id = buf_id;
97+
cuda_mem_attr->super.mem_type = mem_type;
98+
cuda_mem_attr->super.sys_dev = sys_dev;
99+
cuda_mem_attr->super.cmp = uct_cuda_mem_attr_cmp;
100+
cuda_mem_attr->super.destroy = uct_cuda_mem_attr_destroy;
101+
102+
*mem_attr_p = &cuda_mem_attr->super;
103+
return UCS_OK;
104+
}
105+
21106
ucs_status_t uct_cuda_base_get_sys_dev(CUdevice cuda_device,
22107
ucs_sys_device_t *sys_dev_p)
23108
{
@@ -61,7 +146,8 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_base_detect_memory_type,
61146
uct_md_h md, const void *address, size_t length,
62147
ucs_memory_type_t *mem_type_p)
63148
{
64-
uct_md_mem_attr_t mem_attr;
149+
/* self-initializing to suppress wrong maybe-uninitialized error */
150+
uct_md_mem_attr_t mem_attr = mem_attr;
65151
ucs_status_t status;
66152

67153
mem_attr.field_mask = UCT_MD_MEM_ATTR_FIELD_MEM_TYPE;
@@ -80,72 +166,43 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_base_mem_query,
80166
uct_md_h md, const void *address, size_t length,
81167
uct_md_mem_attr_t *mem_attr)
82168
{
83-
#define UCT_CUDA_MEM_QUERY_NUM_ATTRS 3
84-
CUmemorytype cuda_mem_mype = (CUmemorytype)0;
85-
uint32_t is_managed = 0;
86-
unsigned value = 1;
87-
CUdevice cuda_device = -1;
88-
CUpointer_attribute attr_type[UCT_CUDA_MEM_QUERY_NUM_ATTRS];
89-
void *attr_data[UCT_CUDA_MEM_QUERY_NUM_ATTRS];
90-
ucs_memory_type_t mem_type;
91-
const char *cu_err_str;
169+
uct_mem_attr_h mem_attr_h;
92170
ucs_status_t status;
93-
CUresult cu_err;
94171

95172
if (!(mem_attr->field_mask & (UCT_MD_MEM_ATTR_FIELD_MEM_TYPE |
96173
UCT_MD_MEM_ATTR_FIELD_SYS_DEV))) {
97174
return UCS_OK;
98175
}
99176

100-
if (address == NULL) {
101-
mem_type = UCS_MEMORY_TYPE_HOST;
102-
if (mem_attr->field_mask & UCT_MD_MEM_ATTR_FIELD_SYS_DEV) {
103-
mem_attr->sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
104-
}
105-
} else {
106-
attr_type[0] = CU_POINTER_ATTRIBUTE_MEMORY_TYPE;
107-
attr_data[0] = &cuda_mem_mype;
108-
attr_type[1] = CU_POINTER_ATTRIBUTE_IS_MANAGED;
109-
attr_data[1] = &is_managed;
110-
attr_type[2] = CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
111-
attr_data[2] = &cuda_device;
112-
113-
cu_err = cuPointerGetAttributes(ucs_static_array_size(attr_data),
114-
attr_type, attr_data,
115-
(CUdeviceptr)address);
116-
if ((cu_err != CUDA_SUCCESS) || (cuda_mem_mype != CU_MEMORYTYPE_DEVICE)) {
117-
/* pointer not recognized */
118-
return UCS_ERR_INVALID_ADDR;
119-
}
177+
status = uct_cuda_mem_attr_query(address, length, &mem_attr_h);
178+
if (status != UCS_OK) {
179+
return status;
180+
}
120181

121-
if (is_managed) {
122-
mem_type = UCS_MEMORY_TYPE_CUDA_MANAGED;
123-
} else {
124-
mem_type = UCS_MEMORY_TYPE_CUDA;
125-
126-
/* Synchronize for DMA */
127-
cu_err = cuPointerSetAttribute(&value,
128-
CU_POINTER_ATTRIBUTE_SYNC_MEMOPS,
129-
(CUdeviceptr)address);
130-
if (cu_err != CUDA_SUCCESS) {
131-
cuGetErrorString(cu_err, &cu_err_str);
132-
ucs_warn("cuPointerSetAttribute(%p) error: %s", address,
133-
cu_err_str);
134-
}
182+
if (uct_mem_attr_get_type(mem_attr_h) == UCS_MEMORY_TYPE_CUDA) {
183+
unsigned value = 1;
184+
CUresult cu_err;
185+
const char *cu_err_str;
186+
/* Synchronize for DMA */
187+
cu_err = cuPointerSetAttribute(&value,
188+
CU_POINTER_ATTRIBUTE_SYNC_MEMOPS,
189+
(CUdeviceptr)address);
190+
if (cu_err != CUDA_SUCCESS) {
191+
cuGetErrorString(cu_err, &cu_err_str);
192+
ucs_warn("cuPointerSetAttribute(%p) error: %s", address,
193+
cu_err_str);
135194
}
195+
}
136196

137-
if (mem_attr->field_mask & UCT_MD_MEM_ATTR_FIELD_SYS_DEV) {
138-
status = uct_cuda_base_get_sys_dev(cuda_device, &mem_attr->sys_dev);
139-
if (status != UCS_OK) {
140-
return status;
141-
}
142-
}
197+
if (mem_attr->field_mask & UCT_MD_MEM_ATTR_FIELD_SYS_DEV) {
198+
mem_attr->sys_dev = uct_mem_attr_get_sys_dev(mem_attr_h);
143199
}
144200

145201
if (mem_attr->field_mask & UCT_MD_MEM_ATTR_FIELD_MEM_TYPE) {
146-
mem_attr->mem_type = mem_type;
202+
mem_attr->mem_type = uct_mem_attr_get_type(mem_attr_h);
147203
}
148204

205+
uct_mem_attr_destroy(mem_attr_h);
149206
return UCS_OK;
150207
}
151208

@@ -172,3 +229,5 @@ UCS_MODULE_INIT() {
172229
UCS_MODULE_FRAMEWORK_LOAD(uct_cuda, 0);
173230
return UCS_OK;
174231
}
232+
233+
UCT_MEM_QUERY_REGISTER(uct_cuda_mem_attr_query, UCS_MEMORY_TYPE_CUDA);

src/uct/rocm/base/rocm_base.c

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,67 @@ ucs_status_t uct_rocm_base_mem_query(uct_md_h md, const void *addr,
223223
return UCS_OK;
224224
}
225225

226+
typedef struct uct_rocm_mem_attr {
227+
uct_mem_attr_t super;
228+
/* TODO what rocm memory attributes do we need? */
229+
} uct_rocm_mem_attr_t;
230+
231+
/**
 * Comparison callback for ROCm memory attribute objects.
 *
 * @param [in] mem_attr1  First memory attribute handle (not inspected).
 * @param [in] mem_attr2  Second memory attribute handle (not inspected).
 *
 * @return Always 0, i.e. any two ROCm attribute objects compare equal.
 */
static int uct_rocm_mem_attr_cmp(const uct_mem_attr_h mem_attr1,
                                 const uct_mem_attr_h mem_attr2)
{
    /* TODO how should we compare two rocm memory attributes? */
    const int rocm_attrs_equal = 0;

    return rocm_attrs_equal;
}
237+
238+
/**
 * Free a ROCm memory attribute object.
 *
 * @param [in] mem_attr  Handle that refers to a uct_rocm_mem_attr_t; the
 *                       containing object is released with ucs_free().
 */
static void uct_rocm_mem_attr_destroy(uct_mem_attr_h mem_attr)
{
    ucs_free(ucs_derived_of(mem_attr, uct_rocm_mem_attr_t));
}
244+
245+
static ucs_status_t ucm_rocmmem_attr_get(const void *address, size_t length,
246+
uct_mem_attr_h *mem_attr_p)
247+
{
248+
hsa_status_t status;
249+
hsa_amd_pointer_info_t info = {
250+
.size = sizeof(hsa_amd_pointer_info_t),
251+
};
252+
uct_rocm_mem_attr_t *rocm_mem_attr;
253+
254+
if (address == NULL) {
255+
return UCS_ERR_INVALID_ADDR;
256+
}
257+
258+
status = hsa_amd_pointer_info((void*)address, &info, NULL, NULL, NULL);
259+
if ((status != HSA_STATUS_SUCCESS) ||
260+
(info.type != HSA_EXT_POINTER_TYPE_HSA)) {
261+
return UCS_ERR_INVALID_ADDR;
262+
}
263+
264+
hsa_device_type_t dev_type;
265+
status = hsa_agent_get_info(info.agentOwner, HSA_AGENT_INFO_DEVICE, &dev_type);
266+
if ((status != HSA_STATUS_SUCCESS) ||
267+
(dev_type != HSA_DEVICE_TYPE_GPU)) {
268+
return UCS_ERR_INVALID_ADDR;
269+
}
270+
271+
rocm_mem_attr = ucs_malloc(sizeof(*rocm_mem_attr), "rocmmem_attr");
272+
if (rocm_mem_attr == NULL) {
273+
return UCS_ERR_NO_MEMORY;
274+
}
275+
rocm_mem_attr->super.mem_type = UCS_MEMORY_TYPE_ROCM;
276+
rocm_mem_attr->super.sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
277+
rocm_mem_attr->super.cmp = uct_rocm_mem_attr_cmp;
278+
rocm_mem_attr->super.destroy = uct_rocm_mem_attr_destroy;
279+
280+
return UCS_OK;
281+
}
282+
226283
UCS_MODULE_INIT() {
227284
UCS_MODULE_FRAMEWORK_DECLARE(uct_rocm);
228285
UCS_MODULE_FRAMEWORK_LOAD(uct_rocm, 0);
229286
return UCS_OK;
230287
}
288+
289+
/* Register the ROCm query function defined in this file for ROCm memory */
UCT_MEM_QUERY_REGISTER(ucm_rocmmem_attr_get, UCS_MEMORY_TYPE_ROCM);

0 commit comments

Comments
 (0)