Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jaxlib/gpu/blas_handle_pool.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ template <>
return Handle(pool, handle, stream);
}

} // namespace jax
} // namespace jax
2 changes: 1 addition & 1 deletion jaxlib/gpu/blas_handle_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ limitations under the License.

namespace jax {

using BlasHandlePool = HandlePool<gpublasHandle_t, gpuStream_t>;
using BlasHandlePool = HandlePool<gpublasHandle_t, gpuStream_t, BlasTag>;

template <>
absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
Expand Down
36 changes: 24 additions & 12 deletions jaxlib/gpu/handle_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,20 @@ limitations under the License.

namespace jax {

// Tag types for unique pool instantiations. They are empty structs used only
// as the third template argument of HandlePool, so that two aliases with the
// same HandleType and StreamType still name distinct HandlePool types.
// Default for HandlePool's Tag parameter when no tag is supplied.
struct DefaultTag {};
// Tag used by the BlasHandlePool alias.
struct BlasTag {};
// Tag used by the SolverHandlePool and SpSolverHandlePool aliases.
struct SolverTag {};

// To avoid creating cublas/cusolver contexts in the middle of execution, we
// maintain a pool of them.
template <typename HandleType, typename StreamType>

// The Tag template parameter ensures a distinct pool instantiation for each GPU
// library (BLAS, SOLVER, etc.). Without it, two aliases whose HandleType and
// StreamType are the same underlying types would name the same HandlePool
// specialization and therefore share one static pool, mixing handles that
// belong to different GPU library contexts (e.g., hipBLAS and hipSOLVER
// handles ending up in the same pool).
template <typename HandleType, typename StreamType, typename Tag = DefaultTag>
class HandlePool {
public:
HandlePool() = default;
Expand Down Expand Up @@ -66,11 +77,11 @@ class HandlePool {
HandleType get() { return handle_; }

private:
friend class HandlePool<HandleType, StreamType>;
Handle(HandlePool<HandleType, StreamType>* pool, HandleType handle,
friend class HandlePool<HandleType, StreamType, Tag>;
Handle(HandlePool<HandleType, StreamType, Tag>* pool, HandleType handle,
StreamType stream)
: pool_(pool), handle_(handle), stream_(stream) {}
HandlePool<HandleType, StreamType>* pool_ = nullptr;
HandlePool<HandleType, StreamType, Tag>* pool_ = nullptr;
HandleType handle_ = nullptr;
StreamType stream_ = nullptr;
};
Expand All @@ -80,31 +91,32 @@ class HandlePool {
static absl::StatusOr<Handle> Borrow(StreamType stream);

private:
static HandlePool<HandleType, StreamType>* Instance();
static HandlePool<HandleType, StreamType, Tag>* Instance();

void Return(HandleType handle, StreamType stream);

absl::Mutex mu_;
std::map<StreamType, std::vector<HandleType>> handles_ ABSL_GUARDED_BY(mu_);
};

// Returns the pool singleton for this HandlePool specialization. The pool is
// created with `new` on the first call and kept in a function-local static;
// nothing in this function deletes it.
template <typename HandleType, typename StreamType>
/*static*/ HandlePool<HandleType, StreamType>*
HandlePool<HandleType, StreamType>::Instance() {
static auto* pool = new HandlePool<HandleType, StreamType>;
template <typename HandleType, typename StreamType, typename Tag>
/*static*/ HandlePool<HandleType, StreamType, Tag>*
HandlePool<HandleType, StreamType, Tag>::Instance() {
static auto* pool = new HandlePool<HandleType, StreamType, Tag>;
return pool;
}

// Hands `handle` back to the pool by appending it to the vector of handles
// stored under `stream` in handles_.
template <typename HandleType, typename StreamType>
void HandlePool<HandleType, StreamType>::Return(HandleType handle,
StreamType stream) {
template <typename HandleType, typename StreamType, typename Tag>
void HandlePool<HandleType, StreamType, Tag>::Return(HandleType handle,
StreamType stream) {
// handles_ is declared ABSL_GUARDED_BY(mu_), so hold mu_ while touching it.
absl::MutexLock lock(&mu_);
// operator[] creates the per-stream vector if this stream has no entry yet.
handles_[stream].push_back(handle);
}

// template <typename HandleType, typename StreamType, typename Tag>
// HandlePool<HandleType, StreamType, Tag>::Borrow(StreamType stream)


} // namespace jax

#endif // JAXLIB_GPU_HANDLE_POOL_H_
3 changes: 1 addition & 2 deletions jaxlib/gpu/solver_handle_pool.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/

#include "jaxlib/gpu/solver_handle_pool.h"

#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
Expand All @@ -40,7 +39,7 @@ template <>
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}
Expand Down
4 changes: 2 additions & 2 deletions jaxlib/gpu/solver_handle_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ limitations under the License.

namespace jax {

using SolverHandlePool = HandlePool<gpusolverDnHandle_t, gpuStream_t>;
using SolverHandlePool = HandlePool<gpusolverDnHandle_t, gpuStream_t, SolverTag>;

template <>
absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(
gpuStream_t stream);

#ifdef JAX_GPU_CUDA
using SpSolverHandlePool = HandlePool<cusolverSpHandle_t, gpuStream_t>;
using SpSolverHandlePool = HandlePool<cusolverSpHandle_t, gpuStream_t, SolverTag>;

template <>
absl::StatusOr<SpSolverHandlePool::Handle> SpSolverHandlePool::Borrow(
Expand Down