Skip to content

Commit 4d49336

Browse files
committed
Add opaque pointer types for BLAS/SOLVER handle pool separation
Addresses handle pool singleton sharing issue between different GPU operation types in ROCm/HIP backend. Changes made: - Add opaque pointer typedefs for HIP backend: * typedef struct hipblasHandle_* gpublasHandle_t * typedef struct hipsolverHandle_* gpusolverDnHandle_t - Implement inline wrapper functions for type-safe handle operations: * gpublasCreate() and gpublasSetStream() for BLAS handles * gpusolverDnCreate() and gpusolverDnSetStream() for SOLVER handles
1 parent 3d6b521 commit 4d49336

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

jaxlib/gpu/vendor.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -431,10 +431,12 @@ typedef hipDoubleComplex gpuDoubleComplex;
431431
typedef hipComplex gpublasComplex;
432432
typedef hipDoubleComplex gpublasDoubleComplex;
433433

434-
typedef hipsolverHandle_t gpusolverDnHandle_t;
434+
// Create unique opaque pointer types for proper singleton separation - BLAS and SOLVER only
435+
typedef struct hipblasHandle_* gpublasHandle_t;
436+
typedef struct hipsolverHandle_* gpusolverDnHandle_t;
437+
435438
typedef hipblasFillMode_t gpublasFillMode_t;
436439
typedef hipsolverFillMode_t gpusolverFillMode_t;
437-
typedef hipblasHandle_t gpublasHandle_t;
438440
typedef hipblasOperation_t gpublasOperation_t;
439441
typedef hipblasStatus_t gpublasStatus_t;
440442
typedef hipCtx_t gpuContext_t;
@@ -480,8 +482,15 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
480482
#define GPU_C_64F HIP_C_64F
481483
#define GPU_R_64F HIP_R_64F
482484

483-
#define gpublasCreate hipblasCreate
484-
#define gpublasSetStream hipblasSetStream
485+
// Inline wrapper functions for BLAS handles to ensure unique types
486+
inline hipblasStatus_t gpublasCreate(gpublasHandle_t* handle) {
487+
return hipblasCreate(reinterpret_cast<hipblasHandle_t*>(handle));
488+
}
489+
490+
inline hipblasStatus_t gpublasSetStream(gpublasHandle_t handle, hipStream_t stream) {
491+
return hipblasSetStream(reinterpret_cast<hipblasHandle_t>(handle), stream);
492+
}
493+
485494
#define gpublasSgeqrfBatched hipblasSgeqrfBatched
486495
#define gpublasDgeqrfBatched hipblasDgeqrfBatched
487496
#define gpublasCgeqrfBatched hipblasCgeqrfBatched
@@ -531,8 +540,15 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
531540
#define GPUDNN_LSTM miopenLSTM
532541
#define GPUDNN_BIDIRECTIONAL miopenRNNbidirection
533542

534-
#define gpusolverDnCreate hipsolverCreate
535-
#define gpusolverDnSetStream hipsolverSetStream
543+
// Inline wrapper functions for SOLVER handles to ensure unique types
544+
inline hipsolverStatus_t gpusolverDnCreate(gpusolverDnHandle_t* handle) {
545+
return hipsolverCreate(reinterpret_cast<hipsolverHandle_t*>(handle));
546+
}
547+
548+
inline hipsolverStatus_t gpusolverDnSetStream(gpusolverDnHandle_t handle, hipStream_t stream) {
549+
return hipsolverSetStream(reinterpret_cast<hipsolverHandle_t>(handle), stream);
550+
}
551+
536552
#define gpusolverDnCreateSyevjInfo hipsolverCreateSyevjInfo
537553
#define gpusolverDnDestroySyevjInfo hipsolverDestroySyevjInfo
538554
#define gpusolverDnSgeqrf hipsolverSgeqrf

0 commit comments

Comments
 (0)