1616#include <ucm/util/reloc.h>
1717#include <ucm/util/replace.h>
1818#include <ucm/util/sys.h>
19+ #include <ucm/mem_attr/mem_attr_int.h>
20+ #include <ucs/debug/memtrack.h>
1921#include <ucs/debug/assert.h>
2022#include <ucs/sys/compiler.h>
2123#include <ucs/sys/preprocessor.h>
22- #include <ucs/memory/memory_type.h>
23- #include <ucs/type/status.h>
2424
2525#include <sys/mman.h>
2626#include <pthread.h>
@@ -457,12 +457,35 @@ static void ucm_cudamem_get_existing_alloc(ucm_event_handler_t *handler)
457457 }
458458}
459459
460+ typedef struct ucm_cudamem_attr {
461+ ucm_mem_attr_t super ;
462+ unsigned long long buf_id ;
463+ } ucm_cudamem_attr_t ;
464+
465+ static int ucm_cudamem_attr_cmp (const ucm_mem_attr_h mem_attr1 ,
466+ const ucm_mem_attr_h mem_attr2 )
467+ {
468+ ucm_cudamem_attr_t * cuda_mem_attr1 , * cuda_mem_attr2 ;
469+ cuda_mem_attr1 = ucs_derived_of (mem_attr1 , ucm_cudamem_attr_t );
470+ cuda_mem_attr2 = ucs_derived_of (mem_attr2 , ucm_cudamem_attr_t );
471+ if (cuda_mem_attr1 -> buf_id == cuda_mem_attr2 -> buf_id ) return 0 ;
472+ return 1 ;
473+ }
474+
475+ static void ucm_cudamem_attr_destroy (ucm_mem_attr_h mem_attr )
476+ {
477+ ucm_cudamem_attr_t * cuda_mem_attr ;
478+ cuda_mem_attr = ucs_derived_of (mem_attr ,ucm_cudamem_attr_t );
479+ ucs_free (cuda_mem_attr );
480+ }
481+
460482static ucs_status_t ucm_cudamem_attr_get (const void * address , size_t length ,
461- ucs_memory_attr_t * mem_attr )
483+ ucm_mem_attr_h * mem_attr_p )
462484{
463485#define UCM_CUDA_MEM_QUERY_NUM_ATTRS 3
464486 CUmemorytype cuda_mem_type = (CUmemorytype )0 ;
465487 uint32_t is_managed = 0 ;
488+ unsigned long long buf_id = 0 ;
466489 CUpointer_attribute attr_type [UCM_CUDA_MEM_QUERY_NUM_ATTRS ];
467490 void * attr_data [UCM_CUDA_MEM_QUERY_NUM_ATTRS ];
468491 CUresult cu_err ;
@@ -474,23 +497,31 @@ static ucs_status_t ucm_cudamem_attr_get(const void *address, size_t length,
474497 attr_type [1 ] = CU_POINTER_ATTRIBUTE_IS_MANAGED ;
475498 attr_data [1 ] = & is_managed ;
476499 attr_type [2 ] = CU_POINTER_ATTRIBUTE_BUFFER_ID ;
477- attr_data [2 ] = & mem_attr -> cuda . buf_id ;
500+ attr_data [2 ] = & buf_id ;
478501
479502 cu_err = cuPointerGetAttributes (ucs_static_array_size (attr_data ),
480503 attr_type , attr_data ,
481504 (CUdeviceptr )address );
482505 if (cu_err == CUDA_SUCCESS ) {
506+ ucm_cudamem_attr_t * mem_attr ;
507+ mem_attr = ucs_malloc (sizeof (* mem_attr ),"cudamem_attr" );
508+ if (mem_attr == NULL ) return UCS_ERR_NO_MEMORY ;
483509 switch (cuda_mem_type ) {
484510 case CU_MEMORYTYPE_DEVICE :
485- mem_attr -> mem_type = is_managed ? UCS_MEMORY_TYPE_CUDA_MANAGED
486- : UCS_MEMORY_TYPE_CUDA ;
511+ mem_attr -> super .mem_type = is_managed
512+ ? UCS_MEMORY_TYPE_CUDA_MANAGED
513+ : UCS_MEMORY_TYPE_CUDA ;
487514 break ;
488515 case CU_MEMORYTYPE_HOST :
489- mem_attr -> mem_type = UCS_MEMORY_TYPE_HOST ;
516+ mem_attr -> super . mem_type = UCS_MEMORY_TYPE_HOST ;
490517 break ;
491518 default :
492519 return UCS_ERR_INVALID_ADDR ;
493520 }
521+ mem_attr -> buf_id = buf_id ;
522+ mem_attr -> super .cmp = & ucm_cudamem_attr_cmp ;
523+ mem_attr -> super .destroy = & ucm_cudamem_attr_destroy ;
524+ * mem_attr_p = & mem_attr -> super ;
494525 return UCS_OK ;
495526 }
496527
0 commit comments