3737#include < TiledArray/device/um_storage.h>
3838#include < TiledArray/external/device.h>
3939#include < TiledArray/external/librett.h>
40- #include < TiledArray/fwd.h>
4140#include < TiledArray/math/gemm_helper.h>
4241#include < TiledArray/platform.h>
4342#include < TiledArray/range.h>
@@ -69,17 +68,17 @@ void to_host(const UMTensor<T> &tensor) {
6968 stream);
7069}
7170
72- // / get device data pointer
73- template <typename T>
74- auto *device_data (const UMTensor<T> &tensor) {
75- return tensor.data ();
76- }
71+ // // / get device data pointer
72+ // template <typename T>
73+ // auto *device_data(const UMTensor<T> &tensor) {
74+ // return tensor.data();
75+ // }
7776
78- // / get device data pointer (non-const)
79- template <typename T>
80- auto *device_data (UMTensor<T> &tensor) {
81- return tensor.data ();
82- }
77+ // // / get device data pointer (non-const)
78+ // template <typename T>
79+ // auto *device_data(UMTensor<T> &tensor) {
80+ // return tensor.data();
81+ // }
8382
8483// / handle ComplexConjugate handling for scaling functions
8584// / follows the logic in device/btas.h
@@ -159,8 +158,8 @@ UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
159158
160159 blas::gemm (blas::Layout::ColMajor, gemm_helper.right_op (),
161160 gemm_helper.left_op (), n, m, k, factor_t ,
162- detail:: device_data (right), ldb, detail:: device_data (left), lda,
163- zero, detail:: device_data (result), ldc, queue);
161+ device_data (right), ldb, device_data (left), lda,
162+ zero, device_data (result), ldc, queue);
164163
165164 device::sync_madness_task_with (stream);
166165 return result;
@@ -220,8 +219,8 @@ void gemm(UMTensor<T> &result, const UMTensor<T> &left,
220219
221220 blas::gemm (blas::Layout::ColMajor, gemm_helper.right_op (),
222221 gemm_helper.left_op (), n, m, k, factor_t ,
223- detail:: device_data (right), ldb, detail:: device_data (left), lda,
224- one, detail:: device_data (result), ldc, queue);
222+ device_data (right), ldb, device_data (left), lda,
223+ one, device_data (result), ldc, queue);
225224
226225 device::sync_madness_task_with (stream);
227226}
@@ -242,8 +241,8 @@ UMTensor<T> clone(const UMTensor<T> &arg) {
242241
243242 // copy data
244243 auto &queue = blasqueue_for (result.range ());
245- blas::copy (result.size (), detail:: device_data (arg), 1 ,
246- detail:: device_data (result), 1 , queue);
244+ blas::copy (result.size (), device_data (arg), 1 ,
245+ device_data (result), 1 , queue);
247246 device::sync_madness_task_with (stream);
248247 return result;
249248}
@@ -270,8 +269,8 @@ UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
270269 detail::to_device (result);
271270
272271 // copy data
273- blas::copy (result.size (), detail:: device_data (arg), 1 ,
274- detail:: device_data (result), 1 , queue);
272+ blas::copy (result.size (), device_data (arg), 1 ,
273+ device_data (result), 1 , queue);
275274 device::sync_madness_task_with (stream);
276275 return result;
277276}
@@ -302,8 +301,8 @@ UMTensor<T> permute(const UMTensor<T> &arg,
302301 detail::to_device (result);
303302
304303 // invoke permute function from librett
305- librett_permute (const_cast <T *>(detail:: device_data (arg)),
306- detail:: device_data (result), arg.range (), perm, stream);
304+ librett_permute (const_cast <T *>(device_data (arg)),
305+ device_data (result), arg.range (), perm, stream);
307306 device::sync_madness_task_with (stream);
308307 return result;
309308}
@@ -328,7 +327,7 @@ UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor) {
328327
329328 auto result = clone (arg);
330329
331- detail::apply_scale_factor (detail:: device_data (result), result.size (), factor,
330+ detail::apply_scale_factor (device_data (result), result.size (), factor,
332331 queue);
333332
334333 device::sync_madness_task_with (stream);
@@ -345,7 +344,7 @@ UMTensor<T> &scale_to(UMTensor<T> &arg, const Scalar factor) {
345344
346345 // in-place scale
347346 // ComplexConjugate is handled as in device/btas.h
348- detail::apply_scale_factor (detail:: device_data (arg), arg.size (), factor,
347+ detail::apply_scale_factor (device_data (arg), arg.size (), factor,
349348 queue);
350349
351350 device::sync_madness_task_with (stream);
@@ -398,10 +397,10 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
398397 detail::to_device (result);
399398
400399 // result = arg1 + arg2
401- blas::copy (result.size (), detail:: device_data (arg1), 1 ,
402- detail:: device_data (result), 1 , queue);
403- blas::axpy (result.size (), 1 , detail:: device_data (arg2), 1 ,
404- detail:: device_data (result), 1 , queue);
400+ blas::copy (result.size (), device_data (arg1), 1 ,
401+ device_data (result), 1 , queue);
402+ blas::axpy (result.size (), 1 , device_data (arg2), 1 ,
403+ device_data (result), 1 , queue);
405404 device::sync_madness_task_with (stream);
406405 return result;
407406}
@@ -444,8 +443,8 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg) {
444443 detail::to_device (arg);
445444
446445 // result += arg
447- blas::axpy (result.size (), 1 , detail:: device_data (arg), 1 ,
448- detail:: device_data (result), 1 , queue);
446+ blas::axpy (result.size (), 1 , device_data (arg), 1 ,
447+ device_data (result), 1 , queue);
449448 device::sync_madness_task_with (stream);
450449 return result;
451450}
@@ -474,10 +473,10 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
474473 detail::to_device (result);
475474
476475 // result = arg1 - arg2
477- blas::copy (result.size (), detail:: device_data (arg1), 1 ,
478- detail:: device_data (result), 1 , queue);
479- blas::axpy (result.size (), T (-1 ), detail:: device_data (arg2), 1 ,
480- detail:: device_data (result), 1 , queue);
476+ blas::copy (result.size (), device_data (arg1), 1 ,
477+ device_data (result), 1 , queue);
478+ blas::axpy (result.size (), T (-1 ), device_data (arg2), 1 ,
479+ device_data (result), 1 , queue);
481480 device::sync_madness_task_with (stream);
482481 return result;
483482}
@@ -520,8 +519,8 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg) {
520519 detail::to_device (arg);
521520
522521 // result -= arg
523- blas::axpy (result.size (), T (-1 ), detail:: device_data (arg), 1 ,
524- detail:: device_data (result), 1 , queue);
522+ blas::axpy (result.size (), T (-1 ), device_data (arg), 1 ,
523+ device_data (result), 1 , queue);
525524 device::sync_madness_task_with (stream);
526525 return result;
527526}
@@ -551,8 +550,8 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
551550 detail::to_device (result);
552551
553552 // element-wise multiplication
554- device::mult_kernel (detail:: device_data (result), detail:: device_data (arg1),
555- detail:: device_data (arg2), arg1.size (), stream);
553+ device::mult_kernel (device_data (result), device_data (arg1),
554+ device_data (arg2), arg1.size (), stream);
556555 device::sync_madness_task_with (stream);
557556 return result;
558557}
@@ -595,7 +594,7 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg) {
595594 detail::to_device (arg);
596595
597596 // in-place element-wise multiplication
598- device::mult_to_kernel (detail:: device_data (result), detail:: device_data (arg),
597+ device::mult_to_kernel (device_data (result), device_data (arg),
599598 result.size (), stream);
600599
601600 device::sync_madness_task_with (stream);
@@ -624,8 +623,8 @@ T dot(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
624623
625624 // compute dot product using device BLAS
626625 auto result = T (0 );
627- blas::dot (arg1.size (), detail:: device_data (arg1), 1 ,
628- detail:: device_data (arg2), 1 , &result, queue);
626+ blas::dot (arg1.size (), device_data (arg1), 1 ,
627+ device_data (arg2), 1 , &result, queue);
629628 device::sync_madness_task_with (stream);
630629 return result;
631630}
@@ -643,7 +642,7 @@ T squared_norm(const UMTensor<T> &arg) {
643642
644643 // compute squared norm using dot
645644 auto result = T (0 );
646- blas::dot (arg.size (), detail:: device_data (arg), 1 , detail:: device_data (arg),
645+ blas::dot (arg.size (), device_data (arg), 1 , device_data (arg),
647646 1 , &result, queue);
648647 device::sync_madness_task_with (stream);
649648 return result;
@@ -659,7 +658,7 @@ T sum(const UMTensor<T> &arg) {
659658 detail::to_device (arg);
660659 auto stream = device::stream_for (arg.range ());
661660 auto result =
662- device::sum_kernel (detail:: device_data (arg), arg.size (), stream);
661+ device::sum_kernel (device_data (arg), arg.size (), stream);
663662 device::sync_madness_task_with (stream);
664663 return result;
665664}
@@ -669,7 +668,7 @@ T product(const UMTensor<T> &arg) {
669668 detail::to_device (arg);
670669 auto stream = device::stream_for (arg.range ());
671670 auto result =
672- device::product_kernel (detail:: device_data (arg), arg.size (), stream);
671+ device::product_kernel (device_data (arg), arg.size (), stream);
673672 device::sync_madness_task_with (stream);
674673 return result;
675674}
@@ -679,7 +678,7 @@ T max(const UMTensor<T> &arg) {
679678 detail::to_device (arg);
680679 auto stream = device::stream_for (arg.range ());
681680 auto result =
682- device::max_kernel (detail:: device_data (arg), arg.size (), stream);
681+ device::max_kernel (device_data (arg), arg.size (), stream);
683682 device::sync_madness_task_with (stream);
684683 return result;
685684}
@@ -689,7 +688,7 @@ T min(const UMTensor<T> &arg) {
689688 detail::to_device (arg);
690689 auto stream = device::stream_for (arg.range ());
691690 auto result =
692- device::min_kernel (detail:: device_data (arg), arg.size (), stream);
691+ device::min_kernel (device_data (arg), arg.size (), stream);
693692 device::sync_madness_task_with (stream);
694693 return result;
695694}
@@ -699,7 +698,7 @@ T abs_max(const UMTensor<T> &arg) {
699698 detail::to_device (arg);
700699 auto stream = device::stream_for (arg.range ());
701700 auto result =
702- device::absmax_kernel (detail:: device_data (arg), arg.size (), stream);
701+ device::absmax_kernel (device_data (arg), arg.size (), stream);
703702 device::sync_madness_task_with (stream);
704703 return result;
705704}
@@ -709,7 +708,7 @@ T abs_min(const UMTensor<T> &arg) {
709708 detail::to_device (arg);
710709 auto stream = device::stream_for (arg.range ());
711710 auto result =
712- device::absmin_kernel (detail:: device_data (arg), arg.size (), stream);
711+ device::absmin_kernel (device_data (arg), arg.size (), stream);
713712 device::sync_madness_task_with (stream);
714713
715714 return result;
0 commit comments