1818#include <cuda.h>
1919
2020
21+ typedef struct uct_cuda_mem_attr {
22+ uct_mem_attr_t super ;
23+ unsigned long long buf_id ;
24+ void * base_address ;
25+ size_t alloc_length ;
26+ } uct_cuda_mem_attr_t ;
27+
28+ static int uct_cuda_mem_attr_cmp (const uct_mem_attr_h mem_attr1 ,
29+ const uct_mem_attr_h mem_attr2 )
30+ {
31+ uct_cuda_mem_attr_t * cuda_mem_attr1 , * cuda_mem_attr2 ;
32+ cuda_mem_attr1 = ucs_derived_of (mem_attr1 , uct_cuda_mem_attr_t );
33+ cuda_mem_attr2 = ucs_derived_of (mem_attr2 , uct_cuda_mem_attr_t );
34+ return cuda_mem_attr1 -> buf_id == cuda_mem_attr2 -> buf_id ? 0 : 1 ;
35+ }
36+
37+ static void uct_cuda_mem_attr_destroy (uct_mem_attr_h mem_attr )
38+ {
39+ uct_cuda_mem_attr_t * cuda_mem_attr ;
40+ cuda_mem_attr = ucs_derived_of (mem_attr , uct_cuda_mem_attr_t );
41+ ucs_free (cuda_mem_attr );
42+ }
43+
44+ UCS_PROFILE_FUNC (ucs_status_t , uct_cuda_mem_attr_query ,
45+ (address , length , mem_attr_p ),
46+ const void * address , size_t length ,
47+ uct_mem_attr_h * mem_attr_p )
48+ {
49+ #define UCT_CUDA_MEM_QUERY_NUM_ATTRS 4
50+ CUmemorytype cuda_mem_mype = (CUmemorytype )0 ;
51+ unsigned long long buf_id = 0 ;
52+ uint32_t is_managed = 0 ;
53+ CUdevice cuda_device = -1 ;
54+ void * base_address = (void * )address ;
55+ size_t alloc_length = length ;
56+ ucs_sys_device_t sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN ;
57+ CUpointer_attribute attr_type [UCT_CUDA_MEM_QUERY_NUM_ATTRS ];
58+ void * attr_data [UCT_CUDA_MEM_QUERY_NUM_ATTRS ];
59+ ucs_memory_type_t mem_type ;
60+ uct_cuda_mem_attr_t * cuda_mem_attr ;
61+ CUresult cu_err ;
62+ const char * cu_err_str ;
63+ ucs_status_t status ;
64+
65+ if (address == NULL ) {
66+ return UCS_ERR_INVALID_ADDR ;
67+ }
68+
69+ attr_type [0 ] = CU_POINTER_ATTRIBUTE_MEMORY_TYPE ;
70+ attr_data [0 ] = & cuda_mem_mype ;
71+ attr_type [1 ] = CU_POINTER_ATTRIBUTE_IS_MANAGED ;
72+ attr_data [1 ] = & is_managed ;
73+ attr_type [2 ] = CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL ;
74+ attr_data [2 ] = & cuda_device ;
75+ attr_type [3 ] = CU_POINTER_ATTRIBUTE_BUFFER_ID ;
76+ attr_data [3 ] = & buf_id ;
77+
78+ cu_err = cuPointerGetAttributes (ucs_static_array_size (attr_data ),
79+ attr_type , attr_data ,
80+ (CUdeviceptr )address );
81+ if ((cu_err != CUDA_SUCCESS ) || (cuda_mem_mype != CU_MEMORYTYPE_DEVICE )) {
82+ /* pointer not recognized */
83+ return UCS_ERR_INVALID_ADDR ;
84+ }
85+
86+ if (is_managed ) {
87+ mem_type = UCS_MEMORY_TYPE_CUDA_MANAGED ;
88+ } else {
89+ mem_type = UCS_MEMORY_TYPE_CUDA ;
90+ }
91+
92+ status = uct_cuda_base_get_sys_dev (cuda_device , & sys_dev );
93+ if (status != UCS_OK ) {
94+ return status ;
95+ }
96+
97+ cu_err = cuMemGetAddressRange ((CUdeviceptr * )& base_address ,
98+ & alloc_length , (CUdeviceptr )address );
99+ if (cu_err != CUDA_SUCCESS ) {
100+ cuGetErrorString (cu_err , & cu_err_str );
101+ ucs_error ("ccuMemGetAddressRange(%p) error: %s" , address ,
102+ cu_err_str );
103+ return UCS_ERR_INVALID_ADDR ;
104+ }
105+
106+ cuda_mem_attr = ucs_malloc (sizeof (* cuda_mem_attr ), "cuda_mem_attr" );
107+ if (cuda_mem_attr == NULL ) {
108+ return UCS_ERR_NO_MEMORY ;
109+ }
110+
111+ cuda_mem_attr -> buf_id = buf_id ;
112+ cuda_mem_attr -> base_address = base_address ;
113+ cuda_mem_attr -> alloc_length = alloc_length ;
114+ cuda_mem_attr -> super .mem_type = mem_type ;
115+ cuda_mem_attr -> super .sys_dev = sys_dev ;
116+ cuda_mem_attr -> super .cmp = uct_cuda_mem_attr_cmp ;
117+ cuda_mem_attr -> super .destroy = uct_cuda_mem_attr_destroy ;
118+
119+ * mem_attr_p = & cuda_mem_attr -> super ;
120+ return UCS_OK ;
121+ }
122+
21123ucs_status_t uct_cuda_base_get_sys_dev (CUdevice cuda_device ,
22124 ucs_sys_device_t * sys_dev_p )
23125{
@@ -61,7 +163,8 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_base_detect_memory_type,
61163 uct_md_h md , const void * address , size_t length ,
62164 ucs_memory_type_t * mem_type_p )
63165{
64- uct_md_mem_attr_t mem_attr ;
166+ /* self-initializing to suppress wrong maybe-uninitialized error */
167+ uct_md_mem_attr_t mem_attr = mem_attr ;
65168 ucs_status_t status ;
66169
67170 mem_attr .field_mask = UCT_MD_MEM_ATTR_FIELD_MEM_TYPE ;
@@ -80,20 +183,9 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_base_mem_query,
80183 uct_md_h md , const void * address , size_t length ,
81184 uct_md_mem_attr_t * mem_attr )
82185{
83- #define UCT_CUDA_MEM_QUERY_NUM_ATTRS 3
84- CUmemorytype cuda_mem_mype = (CUmemorytype )0 ;
85- uint32_t is_managed = 0 ;
86- unsigned value = 1 ;
87- CUdevice cuda_device = -1 ;
88- void * base_address = (void * )address ;
89- size_t alloc_length = length ;
90- ucs_sys_device_t sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN ;
91- CUpointer_attribute attr_type [UCT_CUDA_MEM_QUERY_NUM_ATTRS ];
92- void * attr_data [UCT_CUDA_MEM_QUERY_NUM_ATTRS ];
93- ucs_memory_type_t mem_type ;
94- const char * cu_err_str ;
186+ uct_mem_attr_h mem_attr_h ;
187+ uct_cuda_mem_attr_t * cuda_mem_attr ;
95188 ucs_status_t status ;
96- CUresult cu_err ;
97189
98190 if (!(mem_attr -> field_mask & (UCT_MD_MEM_ATTR_FIELD_MEM_TYPE |
99191 UCT_MD_MEM_ATTR_FIELD_SYS_DEV |
@@ -102,76 +194,45 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_base_mem_query,
102194 return UCS_OK ;
103195 }
104196
105- if (address == NULL ) {
106- mem_type = UCS_MEMORY_TYPE_HOST ;
107- } else {
108- attr_type [0 ] = CU_POINTER_ATTRIBUTE_MEMORY_TYPE ;
109- attr_data [0 ] = & cuda_mem_mype ;
110- attr_type [1 ] = CU_POINTER_ATTRIBUTE_IS_MANAGED ;
111- attr_data [1 ] = & is_managed ;
112- attr_type [2 ] = CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL ;
113- attr_data [2 ] = & cuda_device ;
114-
115- cu_err = cuPointerGetAttributes (ucs_static_array_size (attr_data ),
116- attr_type , attr_data ,
117- (CUdeviceptr )address );
118- if ((cu_err != CUDA_SUCCESS ) || (cuda_mem_mype != CU_MEMORYTYPE_DEVICE )) {
119- /* pointer not recognized */
120- return UCS_ERR_INVALID_ADDR ;
121- }
197+ status = uct_cuda_mem_attr_query (address , length , & mem_attr_h );
198+ if (status != UCS_OK ) {
199+ return status ;
200+ }
122201
123- if (is_managed ) {
124- mem_type = UCS_MEMORY_TYPE_CUDA_MANAGED ;
125- } else {
126- mem_type = UCS_MEMORY_TYPE_CUDA ;
127-
128- /* Synchronize for DMA */
129- cu_err = cuPointerSetAttribute (& value ,
130- CU_POINTER_ATTRIBUTE_SYNC_MEMOPS ,
131- (CUdeviceptr )address );
132- if (cu_err != CUDA_SUCCESS ) {
133- cuGetErrorString (cu_err , & cu_err_str );
134- ucs_warn ("cuPointerSetAttribute(%p) error: %s" , address ,
135- cu_err_str );
136- }
202+ if (uct_mem_attr_get_type (mem_attr_h ) == UCS_MEMORY_TYPE_CUDA ) {
203+ unsigned value = 1 ;
204+ CUresult cu_err ;
205+ const char * cu_err_str ;
206+ /* Synchronize for DMA */
207+ cu_err = cuPointerSetAttribute (& value ,
208+ CU_POINTER_ATTRIBUTE_SYNC_MEMOPS ,
209+ (CUdeviceptr )address );
210+ if (cu_err != CUDA_SUCCESS ) {
211+ cuGetErrorString (cu_err , & cu_err_str );
212+ ucs_warn ("cuPointerSetAttribute(%p) error: %s" , address ,
213+ cu_err_str );
137214 }
215+ }
138216
139- if (mem_attr -> field_mask & UCT_MD_MEM_ATTR_FIELD_SYS_DEV ) {
140- status = uct_cuda_base_get_sys_dev (cuda_device , & sys_dev );
141- if (status != UCS_OK ) {
142- return status ;
143- }
144- }
217+ cuda_mem_attr = ucs_derived_of (mem_attr_h , uct_cuda_mem_attr_t );
145218
146- if (mem_attr -> field_mask & (UCT_MD_MEM_ATTR_FIELD_ALLOC_LENGTH |
147- UCT_MD_MEM_ATTR_FIELD_BASE_ADDRESS )) {
148- cu_err = cuMemGetAddressRange ((CUdeviceptr * )& base_address ,
149- & alloc_length , (CUdeviceptr )address );
150- if (cu_err != CUDA_SUCCESS ) {
151- cuGetErrorString (cu_err , & cu_err_str );
152- ucs_error ("ccuMemGetAddressRange(%p) error: %s" , address ,
153- cu_err_str );
154- return UCS_ERR_INVALID_ADDR ;
155- }
156- }
219+ if (mem_attr -> field_mask & UCT_MD_MEM_ATTR_FIELD_SYS_DEV ) {
220+ mem_attr -> sys_dev = cuda_mem_attr -> super .sys_dev ;
157221 }
158222
159223 if (mem_attr -> field_mask & UCT_MD_MEM_ATTR_FIELD_MEM_TYPE ) {
160- mem_attr -> mem_type = mem_type ;
161- }
162-
163- if (mem_attr -> field_mask & UCT_MD_MEM_ATTR_FIELD_SYS_DEV ) {
164- mem_attr -> sys_dev = sys_dev ;
224+ mem_attr -> mem_type = cuda_mem_attr -> super .mem_type ;
165225 }
166226
167227 if (mem_attr -> field_mask & UCT_MD_MEM_ATTR_FIELD_BASE_ADDRESS ) {
168- mem_attr -> base_address = base_address ;
228+ mem_attr -> base_address = cuda_mem_attr -> base_address ;
169229 }
170230
171231 if (mem_attr -> field_mask & UCT_MD_MEM_ATTR_FIELD_ALLOC_LENGTH ) {
172- mem_attr -> alloc_length = alloc_length ;
232+ mem_attr -> alloc_length = cuda_mem_attr -> alloc_length ;
173233 }
174234
235+ uct_mem_attr_destroy (mem_attr_h );
175236 return UCS_OK ;
176237}
177238
@@ -198,3 +259,5 @@ UCS_MODULE_INIT() {
198259 UCS_MODULE_FRAMEWORK_LOAD (uct_cuda , 0 );
199260 return UCS_OK ;
200261}
262+
/* Presumably registers uct_cuda_mem_attr_query() as the query routine for
 * UCS_MEMORY_TYPE_CUDA - confirm against the UCT_MEM_QUERY_REGISTER
 * definition.
 * NOTE(review): the query can also report UCS_MEMORY_TYPE_CUDA_MANAGED;
 * verify whether a separate registration is needed for that type. */
UCT_MEM_QUERY_REGISTER(uct_cuda_mem_attr_query, UCS_MEMORY_TYPE_CUDA);
0 commit comments