Commit 880e654

UMTensor: use generic device_data, no need for overload
1 parent 739d264 commit 880e654
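
This commit drops the UMTensor<T>-specific device_data overloads (left commented out in the diff below) and lets the unqualified device_data(...) calls resolve to a generic overload instead. As a rough illustration only — not the actual TiledArray declaration — a single generic overload along these lines would cover both the const and non-const UMTensor cases, since template argument deduction handles constness:

    // Hypothetical sketch of a generic device_data(); the real declaration in
    // TiledArray may differ. It works for any tensor-like type exposing data(),
    // including UMTensor<T>; Tensor deduces as const-qualified for const arguments.
    template <typename Tensor>
    auto *device_data(Tensor &tensor) {
      // For unified-memory storage the same pointer is valid on host and device,
      // so returning data() directly is sufficient (as the removed overloads did).
      return tensor.data();
    }

With a generic overload like this visible to the functions below, every former detail::device_data(...) call can be written as an unqualified device_data(...), which is what the diff does.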

File tree

1 file changed: +46 −47 lines


src/TiledArray/device/um_tensor.h

Lines changed: 46 additions & 47 deletions
@@ -37,7 +37,6 @@
 #include <TiledArray/device/um_storage.h>
 #include <TiledArray/external/device.h>
 #include <TiledArray/external/librett.h>
-#include <TiledArray/fwd.h>
 #include <TiledArray/math/gemm_helper.h>
 #include <TiledArray/platform.h>
 #include <TiledArray/range.h>
@@ -69,17 +68,17 @@ void to_host(const UMTensor<T> &tensor) {
                  stream);
 }

-/// get device data pointer
-template <typename T>
-auto *device_data(const UMTensor<T> &tensor) {
-  return tensor.data();
-}
+// /// get device data pointer
+// template <typename T>
+// auto *device_data(const UMTensor<T> &tensor) {
+//   return tensor.data();
+// }

-/// get device data pointer (non-const)
-template <typename T>
-auto *device_data(UMTensor<T> &tensor) {
-  return tensor.data();
-}
+// /// get device data pointer (non-const)
+// template <typename T>
+// auto *device_data(UMTensor<T> &tensor) {
+//   return tensor.data();
+// }

 /// handle ComplexConjugate handling for scaling functions
 /// follows the logic in device/btas.h
@@ -159,8 +158,8 @@ UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,

   blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(),
              gemm_helper.left_op(), n, m, k, factor_t,
-             detail::device_data(right), ldb, detail::device_data(left), lda,
-             zero, detail::device_data(result), ldc, queue);
+             device_data(right), ldb, device_data(left), lda,
+             zero, device_data(result), ldc, queue);

   device::sync_madness_task_with(stream);
   return result;
@@ -220,8 +219,8 @@ void gemm(UMTensor<T> &result, const UMTensor<T> &left,

   blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(),
              gemm_helper.left_op(), n, m, k, factor_t,
-             detail::device_data(right), ldb, detail::device_data(left), lda,
-             one, detail::device_data(result), ldc, queue);
+             device_data(right), ldb, device_data(left), lda,
+             one, device_data(result), ldc, queue);

   device::sync_madness_task_with(stream);
 }
@@ -242,8 +241,8 @@ UMTensor<T> clone(const UMTensor<T> &arg) {

   // copy data
   auto &queue = blasqueue_for(result.range());
-  blas::copy(result.size(), detail::device_data(arg), 1,
-             detail::device_data(result), 1, queue);
+  blas::copy(result.size(), device_data(arg), 1,
+             device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -270,8 +269,8 @@ UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
   detail::to_device(result);

   // copy data
-  blas::copy(result.size(), detail::device_data(arg), 1,
-             detail::device_data(result), 1, queue);
+  blas::copy(result.size(), device_data(arg), 1,
+             device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -302,8 +301,8 @@ UMTensor<T> permute(const UMTensor<T> &arg,
   detail::to_device(result);

   // invoke permute function from librett
-  librett_permute(const_cast<T *>(detail::device_data(arg)),
-                  detail::device_data(result), arg.range(), perm, stream);
+  librett_permute(const_cast<T *>(device_data(arg)),
+                  device_data(result), arg.range(), perm, stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -328,7 +327,7 @@ UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor) {

   auto result = clone(arg);

-  detail::apply_scale_factor(detail::device_data(result), result.size(), factor,
+  detail::apply_scale_factor(device_data(result), result.size(), factor,
                              queue);

   device::sync_madness_task_with(stream);
@@ -345,7 +344,7 @@ UMTensor<T> &scale_to(UMTensor<T> &arg, const Scalar factor) {

   // in-place scale
   // ComplexConjugate is handled as in device/btas.h
-  detail::apply_scale_factor(detail::device_data(arg), arg.size(), factor,
+  detail::apply_scale_factor(device_data(arg), arg.size(), factor,
                              queue);

   device::sync_madness_task_with(stream);
@@ -398,10 +397,10 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
   detail::to_device(result);

   // result = arg1 + arg2
-  blas::copy(result.size(), detail::device_data(arg1), 1,
-             detail::device_data(result), 1, queue);
-  blas::axpy(result.size(), 1, detail::device_data(arg2), 1,
-             detail::device_data(result), 1, queue);
+  blas::copy(result.size(), device_data(arg1), 1,
+             device_data(result), 1, queue);
+  blas::axpy(result.size(), 1, device_data(arg2), 1,
+             device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -444,8 +443,8 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg) {
   detail::to_device(arg);

   // result += arg
-  blas::axpy(result.size(), 1, detail::device_data(arg), 1,
-             detail::device_data(result), 1, queue);
+  blas::axpy(result.size(), 1, device_data(arg), 1,
+             device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -474,10 +473,10 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
   detail::to_device(result);

   // result = arg1 - arg2
-  blas::copy(result.size(), detail::device_data(arg1), 1,
-             detail::device_data(result), 1, queue);
-  blas::axpy(result.size(), T(-1), detail::device_data(arg2), 1,
-             detail::device_data(result), 1, queue);
+  blas::copy(result.size(), device_data(arg1), 1,
+             device_data(result), 1, queue);
+  blas::axpy(result.size(), T(-1), device_data(arg2), 1,
+             device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -520,8 +519,8 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg) {
   detail::to_device(arg);

   // result -= arg
-  blas::axpy(result.size(), T(-1), detail::device_data(arg), 1,
-             detail::device_data(result), 1, queue);
+  blas::axpy(result.size(), T(-1), device_data(arg), 1,
+             device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -551,8 +550,8 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
   detail::to_device(result);

   // element-wise multiplication
-  device::mult_kernel(detail::device_data(result), detail::device_data(arg1),
-                      detail::device_data(arg2), arg1.size(), stream);
+  device::mult_kernel(device_data(result), device_data(arg1),
+                      device_data(arg2), arg1.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -595,7 +594,7 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg) {
   detail::to_device(arg);

   // in-place element-wise multiplication
-  device::mult_to_kernel(detail::device_data(result), detail::device_data(arg),
+  device::mult_to_kernel(device_data(result), device_data(arg),
                          result.size(), stream);

   device::sync_madness_task_with(stream);
@@ -624,8 +623,8 @@ T dot(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {

   // compute dot product using device BLAS
   auto result = T(0);
-  blas::dot(arg1.size(), detail::device_data(arg1), 1,
-            detail::device_data(arg2), 1, &result, queue);
+  blas::dot(arg1.size(), device_data(arg1), 1,
+            device_data(arg2), 1, &result, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -643,7 +642,7 @@ T squared_norm(const UMTensor<T> &arg) {

   // compute squared norm using dot
   auto result = T(0);
-  blas::dot(arg.size(), detail::device_data(arg), 1, detail::device_data(arg),
+  blas::dot(arg.size(), device_data(arg), 1, device_data(arg),
             1, &result, queue);
   device::sync_madness_task_with(stream);
   return result;
@@ -659,7 +658,7 @@ T sum(const UMTensor<T> &arg) {
   detail::to_device(arg);
   auto stream = device::stream_for(arg.range());
   auto result =
-      device::sum_kernel(detail::device_data(arg), arg.size(), stream);
+      device::sum_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -669,7 +668,7 @@ T product(const UMTensor<T> &arg) {
   detail::to_device(arg);
   auto stream = device::stream_for(arg.range());
   auto result =
-      device::product_kernel(detail::device_data(arg), arg.size(), stream);
+      device::product_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -679,7 +678,7 @@ T max(const UMTensor<T> &arg) {
   detail::to_device(arg);
   auto stream = device::stream_for(arg.range());
   auto result =
-      device::max_kernel(detail::device_data(arg), arg.size(), stream);
+      device::max_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -689,7 +688,7 @@ T min(const UMTensor<T> &arg) {
   detail::to_device(arg);
   auto stream = device::stream_for(arg.range());
   auto result =
-      device::min_kernel(detail::device_data(arg), arg.size(), stream);
+      device::min_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -699,7 +698,7 @@ T abs_max(const UMTensor<T> &arg) {
   detail::to_device(arg);
   auto stream = device::stream_for(arg.range());
   auto result =
-      device::absmax_kernel(detail::device_data(arg), arg.size(), stream);
+      device::absmax_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -709,7 +708,7 @@ T abs_min(const UMTensor<T> &arg) {
   detail::to_device(arg);
   auto stream = device::stream_for(arg.range());
   auto result =
-      device::absmin_kernel(detail::device_data(arg), arg.size(), stream);
+      device::absmin_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);

   return result;
