1818#include  <cuda.h> 
1919
2020
21+ typedef  struct  uct_cuda_mem_attr  {
22+     uct_mem_attr_t  super ;
23+     unsigned long long  buf_id ;
24+ } uct_cuda_mem_attr_t ;
25+ 
26+ static  int  uct_cuda_mem_attr_cmp (const  uct_mem_attr_h  mem_attr1 ,
27+                                  const  uct_mem_attr_h  mem_attr2 )
28+ {
29+     uct_cuda_mem_attr_t  * cuda_mem_attr1 , * cuda_mem_attr2 ;
30+     cuda_mem_attr1  =  ucs_derived_of (mem_attr1 , uct_cuda_mem_attr_t );
31+     cuda_mem_attr2  =  ucs_derived_of (mem_attr2 , uct_cuda_mem_attr_t );
32+     return  cuda_mem_attr1 -> buf_id  ==  cuda_mem_attr2 -> buf_id  ? 0  : 1 ;
33+ }
34+ 
35+ static  void  uct_cuda_mem_attr_destroy (uct_mem_attr_h  mem_attr )
36+ {
37+     uct_cuda_mem_attr_t  * cuda_mem_attr ;
38+     cuda_mem_attr  =  ucs_derived_of (mem_attr , uct_cuda_mem_attr_t );
39+     ucs_free (cuda_mem_attr );
40+ }
41+ 
42+ UCS_PROFILE_FUNC (ucs_status_t , uct_cuda_mem_attr_query ,
43+                  (address , length , mem_attr_p ),
44+                  const  void  * address , size_t  length ,
45+                  uct_mem_attr_h  * mem_attr_p )
46+ {
47+ #define  UCT_CUDA_MEM_QUERY_NUM_ATTRS  4
48+     CUmemorytype  cuda_mem_mype  =  (CUmemorytype )0 ;
49+     unsigned long long  buf_id   =  0 ;
50+     uint32_t  is_managed         =  0 ;
51+     CUdevice  cuda_device        =  -1 ;
52+     CUpointer_attribute  attr_type [UCT_CUDA_MEM_QUERY_NUM_ATTRS ];
53+     void  * attr_data [UCT_CUDA_MEM_QUERY_NUM_ATTRS ];
54+     ucs_memory_type_t  mem_type ;
55+     ucs_sys_device_t  sys_dev ;
56+     ucs_status_t  status ;
57+     CUresult  cu_err ;
58+     uct_cuda_mem_attr_t  * cuda_mem_attr ;
59+ 
60+     if  (address  ==  NULL ) {
61+         return  UCS_ERR_INVALID_ADDR ;
62+     }
63+ 
64+     attr_type [0 ] =  CU_POINTER_ATTRIBUTE_MEMORY_TYPE ;
65+     attr_data [0 ] =  & cuda_mem_mype ;
66+     attr_type [1 ] =  CU_POINTER_ATTRIBUTE_IS_MANAGED ;
67+     attr_data [1 ] =  & is_managed ;
68+     attr_type [2 ] =  CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL ;
69+     attr_data [2 ] =  & cuda_device ;
70+     attr_type [3 ] =  CU_POINTER_ATTRIBUTE_BUFFER_ID ;
71+     attr_data [3 ] =  & buf_id ;
72+ 
73+     cu_err  =  cuPointerGetAttributes (ucs_static_array_size (attr_data ),
74+                                     attr_type , attr_data ,
75+                                     (CUdeviceptr )address );
76+     if  ((cu_err  !=  CUDA_SUCCESS ) ||  (cuda_mem_mype  !=  CU_MEMORYTYPE_DEVICE )) {
77+         /* pointer not recognized */ 
78+         return  UCS_ERR_INVALID_ADDR ;
79+     }
80+ 
81+     if  (is_managed ) {
82+         mem_type  =  UCS_MEMORY_TYPE_CUDA_MANAGED ;
83+     } else  {
84+         mem_type  =  UCS_MEMORY_TYPE_CUDA ;
85+     }
86+ 
87+     status  =  uct_cuda_base_get_sys_dev (cuda_device , & sys_dev );
88+     if  (status  !=  UCS_OK ) {
89+         return  status ;
90+     }
91+ 
92+     cuda_mem_attr  =  ucs_malloc (sizeof (* cuda_mem_attr ), "cuda_mem_attr" );
93+     if  (cuda_mem_attr  ==  NULL ) {
94+         return  UCS_ERR_NO_MEMORY ;
95+     }
96+     cuda_mem_attr -> buf_id          =  buf_id ;
97+     cuda_mem_attr -> super .mem_type  =  mem_type ;
98+     cuda_mem_attr -> super .sys_dev   =  sys_dev ;
99+     cuda_mem_attr -> super .cmp       =  uct_cuda_mem_attr_cmp ;
100+     cuda_mem_attr -> super .destroy   =  uct_cuda_mem_attr_destroy ;
101+ 
102+     * mem_attr_p  =  & cuda_mem_attr -> super ;
103+     return  UCS_OK ;
104+ }
105+ 
21106ucs_status_t  uct_cuda_base_get_sys_dev (CUdevice  cuda_device ,
22107                                       ucs_sys_device_t  * sys_dev_p )
23108{
@@ -61,7 +146,8 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_base_detect_memory_type,
61146                 uct_md_h  md , const  void  * address , size_t  length ,
62147                 ucs_memory_type_t  * mem_type_p )
63148{
64-     uct_md_mem_attr_t  mem_attr ;
149+     /* self-initializing to suppress wrong maybe-uninitialized error */ 
150+     uct_md_mem_attr_t  mem_attr  =  mem_attr ;
65151    ucs_status_t  status ;
66152
67153    mem_attr .field_mask  =  UCT_MD_MEM_ATTR_FIELD_MEM_TYPE ;
@@ -80,72 +166,43 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_base_mem_query,
80166                 uct_md_h  md , const  void  * address , size_t  length ,
81167                 uct_md_mem_attr_t  * mem_attr )
82168{
83- #define  UCT_CUDA_MEM_QUERY_NUM_ATTRS  3
84-     CUmemorytype  cuda_mem_mype  =  (CUmemorytype )0 ;
85-     uint32_t  is_managed         =  0 ;
86-     unsigned  value              =  1 ;
87-     CUdevice  cuda_device        =  -1 ;
88-     CUpointer_attribute  attr_type [UCT_CUDA_MEM_QUERY_NUM_ATTRS ];
89-     void  * attr_data [UCT_CUDA_MEM_QUERY_NUM_ATTRS ];
90-     ucs_memory_type_t  mem_type ;
91-     const  char  * cu_err_str ;
169+     uct_mem_attr_h  mem_attr_h ;
92170    ucs_status_t  status ;
93-     CUresult  cu_err ;
94171
95172    if  (!(mem_attr -> field_mask  &  (UCT_MD_MEM_ATTR_FIELD_MEM_TYPE  |
96173                                  UCT_MD_MEM_ATTR_FIELD_SYS_DEV ))) {
97174        return  UCS_OK ;
98175    }
99176
100-     if  (address  ==  NULL ) {
101-         mem_type               =  UCS_MEMORY_TYPE_HOST ;
102-         if  (mem_attr -> field_mask  &  UCT_MD_MEM_ATTR_FIELD_SYS_DEV ) {
103-             mem_attr -> sys_dev  =  UCS_SYS_DEVICE_ID_UNKNOWN ;
104-         }
105-     } else  {
106-         attr_type [0 ] =  CU_POINTER_ATTRIBUTE_MEMORY_TYPE ;
107-         attr_data [0 ] =  & cuda_mem_mype ;
108-         attr_type [1 ] =  CU_POINTER_ATTRIBUTE_IS_MANAGED ;
109-         attr_data [1 ] =  & is_managed ;
110-         attr_type [2 ] =  CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL ;
111-         attr_data [2 ] =  & cuda_device ;
112- 
113-         cu_err  =  cuPointerGetAttributes (ucs_static_array_size (attr_data ),
114-                                         attr_type , attr_data ,
115-                                         (CUdeviceptr )address );
116-         if  ((cu_err  !=  CUDA_SUCCESS ) ||  (cuda_mem_mype  !=  CU_MEMORYTYPE_DEVICE )) {
117-             /* pointer not recognized */ 
118-             return  UCS_ERR_INVALID_ADDR ;
119-         }
177+     status  =  uct_cuda_mem_attr_query (address , length , & mem_attr_h );
178+     if  (status  !=  UCS_OK ) {
179+         return  status ;
180+     }
120181
121-         if  (is_managed ) {
122-             mem_type  =  UCS_MEMORY_TYPE_CUDA_MANAGED ;
123-         } else  {
124-             mem_type  =  UCS_MEMORY_TYPE_CUDA ;
125- 
126-             /* Synchronize for DMA */ 
127-             cu_err  =  cuPointerSetAttribute (& value ,
128-                                            CU_POINTER_ATTRIBUTE_SYNC_MEMOPS ,
129-                                            (CUdeviceptr )address );
130-             if  (cu_err  !=  CUDA_SUCCESS ) {
131-                 cuGetErrorString (cu_err , & cu_err_str );
132-                 ucs_warn ("cuPointerSetAttribute(%p) error: %s" , address ,
133-                          cu_err_str );
134-             }
182+     if  (uct_mem_attr_get_type (mem_attr_h ) ==  UCS_MEMORY_TYPE_CUDA ) {
183+         unsigned  value  =  1 ;
184+         CUresult  cu_err ;
185+         const  char  * cu_err_str ;
186+         /* Synchronize for DMA */ 
187+         cu_err  =  cuPointerSetAttribute (& value ,
188+                                        CU_POINTER_ATTRIBUTE_SYNC_MEMOPS ,
189+                                        (CUdeviceptr )address );
190+         if  (cu_err  !=  CUDA_SUCCESS ) {
191+             cuGetErrorString (cu_err , & cu_err_str );
192+             ucs_warn ("cuPointerSetAttribute(%p) error: %s" , address ,
193+                      cu_err_str );
135194        }
195+     }
136196
137-         if  (mem_attr -> field_mask  &  UCT_MD_MEM_ATTR_FIELD_SYS_DEV ) {
138-             status  =  uct_cuda_base_get_sys_dev (cuda_device , & mem_attr -> sys_dev );
139-             if  (status  !=  UCS_OK ) {
140-                 return  status ;
141-             }
142-         }
197+     if  (mem_attr -> field_mask  &  UCT_MD_MEM_ATTR_FIELD_SYS_DEV ) {
198+         mem_attr -> sys_dev  =  uct_mem_attr_get_sys_dev (mem_attr_h );
143199    }
144200
145201    if  (mem_attr -> field_mask  &  UCT_MD_MEM_ATTR_FIELD_MEM_TYPE ) {
146-         mem_attr -> mem_type  =  mem_type ;
202+         mem_attr -> mem_type  =  uct_mem_attr_get_type ( mem_attr_h ) ;
147203    }
148204
205+     uct_mem_attr_destroy (mem_attr_h );
149206    return  UCS_OK ;
150207}
151208
@@ -172,3 +229,5 @@ UCS_MODULE_INIT() {
172229    UCS_MODULE_FRAMEWORK_LOAD (uct_cuda , 0 );
173230    return  UCS_OK ;
174231}
232+ 
233+ UCT_MEM_QUERY_REGISTER (uct_cuda_mem_attr_query , UCS_MEMORY_TYPE_CUDA );
0 commit comments