Skip to content

Commit af551f1

Browse files
committed
UMTensor: restrict kernels to tensor of scalars
1 parent 880e654 commit af551f1

File tree

1 file changed

+45
-17
lines changed

1 file changed

+45
-17
lines changed

src/TiledArray/device/um_tensor.h

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,14 @@ namespace detail {
4848

4949
/// is_device_tile specialization for UMTensor
5050
template <typename T>
51+
requires TiledArray::detail::is_numeric_v<T>
5152
struct is_device_tile<
5253
::TiledArray::Tensor<T, TiledArray::device_um_allocator<T>>>
5354
: public std::true_type {};
5455

5556
/// pre-fetch to device
5657
template <typename T>
58+
requires TiledArray::detail::is_numeric_v<T>
5759
void to_device(const UMTensor<T> &tensor) {
5860
auto stream = device::stream_for(tensor.range());
5961
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(tensor,
@@ -62,6 +64,7 @@ void to_device(const UMTensor<T> &tensor) {
6264

6365
/// pre-fetch to host
6466
template <typename T>
67+
requires TiledArray::detail::is_numeric_v<T>
6568
void to_host(const UMTensor<T> &tensor) {
6669
auto stream = device::stream_for(tensor.range());
6770
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(tensor,
@@ -83,6 +86,7 @@ void to_host(const UMTensor<T> &tensor) {
8386
/// handle ComplexConjugate handling for scaling functions
8487
/// follows the logic in device/btas.h
8588
template <typename T, typename Scalar, typename Queue>
89+
requires TiledArray::detail::is_numeric_v<T>
8690
void apply_scale_factor(T *data, std::size_t size, const Scalar &factor,
8791
Queue &queue) {
8892
if constexpr (TiledArray::detail::is_blas_numeric_v<Scalar> ||
@@ -111,7 +115,7 @@ void apply_scale_factor(T *data, std::size_t size, const Scalar &factor,
111115
///
112116

113117
template <typename T, typename Scalar>
114-
requires TiledArray::detail::is_numeric_v<Scalar>
118+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
115119
UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
116120
Scalar factor,
117121
const TiledArray::math::GemmHelper &gemm_helper) {
@@ -166,7 +170,7 @@ UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
166170
}
167171

168172
template <typename T, typename Scalar>
169-
requires TiledArray::detail::is_numeric_v<Scalar>
173+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
170174
void gemm(UMTensor<T> &result, const UMTensor<T> &left,
171175
const UMTensor<T> &right, Scalar factor,
172176
const TiledArray::math::GemmHelper &gemm_helper) {
@@ -230,6 +234,7 @@ void gemm(UMTensor<T> &result, const UMTensor<T> &left,
230234
///
231235

232236
template <typename T>
237+
requires TiledArray::detail::is_numeric_v<T>
233238
UMTensor<T> clone(const UMTensor<T> &arg) {
234239
TA_ASSERT(!arg.empty());
235240

@@ -252,6 +257,7 @@ UMTensor<T> clone(const UMTensor<T> &arg) {
252257
///
253258

254259
template <typename T, typename Index>
260+
requires TiledArray::detail::is_numeric_v<T>
255261
UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
256262
TA_ASSERT(!arg.empty());
257263

@@ -276,6 +282,7 @@ UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
276282
}
277283

278284
template <typename T, typename Index>
285+
requires TiledArray::detail::is_numeric_v<T>
279286
UMTensor<T> &shift_to(UMTensor<T> &arg, const Index &bound_shift) {
280287
const_cast<TiledArray::Range &>(arg.range()).inplace_shift(bound_shift);
281288
return arg;
@@ -286,6 +293,7 @@ UMTensor<T> &shift_to(UMTensor<T> &arg, const Index &bound_shift) {
286293
///
287294

288295
template <typename T>
296+
requires TiledArray::detail::is_numeric_v<T>
289297
UMTensor<T> permute(const UMTensor<T> &arg,
290298
const TiledArray::Permutation &perm) {
291299
TA_ASSERT(!arg.empty());
@@ -308,6 +316,7 @@ UMTensor<T> permute(const UMTensor<T> &arg,
308316
}
309317

310318
template <typename T>
319+
requires TiledArray::detail::is_numeric_v<T>
311320
UMTensor<T> permute(const UMTensor<T> &arg,
312321
const TiledArray::BipartitePermutation &perm) {
313322
TA_ASSERT(!arg.empty());
@@ -320,7 +329,7 @@ UMTensor<T> permute(const UMTensor<T> &arg,
320329
///
321330

322331
template <typename T, typename Scalar>
323-
requires TiledArray::detail::is_numeric_v<Scalar>
332+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
324333
UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor) {
325334
auto &queue = blasqueue_for(arg.range());
326335
const auto stream = device::Stream(queue.device(), queue.stream());
@@ -335,7 +344,7 @@ UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor) {
335344
}
336345

337346
template <typename T, typename Scalar>
338-
requires TiledArray::detail::is_numeric_v<Scalar>
347+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
339348
UMTensor<T> &scale_to(UMTensor<T> &arg, const Scalar factor) {
340349
auto &queue = blasqueue_for(arg.range());
341350
const auto stream = device::Stream(queue.device(), queue.stream());
@@ -352,7 +361,7 @@ UMTensor<T> &scale_to(UMTensor<T> &arg, const Scalar factor) {
352361
}
353362

354363
template <typename T, typename Scalar, typename Perm>
355-
requires TiledArray::detail::is_numeric_v<Scalar> &&
364+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> &&
356365
TiledArray::detail::is_permutation_v<Perm>
357366
UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor,
358367
const Perm &perm) {
@@ -365,18 +374,20 @@ UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor,
365374
///
366375

367376
template <typename T>
377+
requires TiledArray::detail::is_numeric_v<T>
368378
UMTensor<T> neg(const UMTensor<T> &arg) {
369379
return scale(arg, T(-1.0));
370380
}
371381

372382
template <typename T, typename Perm>
373-
requires TiledArray::detail::is_permutation_v<Perm>
383+
requires TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T>
374384
UMTensor<T> neg(const UMTensor<T> &arg, const Perm &perm) {
375385
auto result = neg(arg);
376386
return permute(result, perm);
377387
}
378388

379389
template <typename T>
390+
requires TiledArray::detail::is_numeric_v<T>
380391
UMTensor<T> &neg_to(UMTensor<T> &arg) {
381392
return scale_to(arg, T(-1.0));
382393
}
@@ -386,6 +397,7 @@ UMTensor<T> &neg_to(UMTensor<T> &arg) {
386397
///
387398

388399
template <typename T>
400+
requires TiledArray::detail::is_numeric_v<T>
389401
UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
390402
UMTensor<T> result(arg1.range());
391403

@@ -406,23 +418,23 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
406418
}
407419

408420
template <typename T, typename Scalar>
409-
requires TiledArray::detail::is_numeric_v<Scalar>
421+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
410422
UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
411423
const Scalar factor) {
412424
auto result = add(arg1, arg2);
413425
return scale_to(result, factor);
414426
}
415427

416428
template <typename T, typename Perm>
417-
requires TiledArray::detail::is_permutation_v<Perm>
429+
requires TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T>
418430
UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
419431
const Perm &perm) {
420432
auto result = add(arg1, arg2);
421433
return permute(result, perm);
422434
}
423435

424436
template <typename T, typename Scalar, typename Perm>
425-
requires TiledArray::detail::is_numeric_v<Scalar> &&
437+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> &&
426438
TiledArray::detail::is_permutation_v<Perm>
427439
UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
428440
const Scalar factor, const Perm &perm) {
@@ -435,6 +447,7 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
435447
///
436448

437449
template <typename T>
450+
requires TiledArray::detail::is_numeric_v<T>
438451
UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg) {
439452
auto &queue = blasqueue_for(result.range());
440453
const auto stream = device::Stream(queue.device(), queue.stream());
@@ -450,7 +463,7 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg) {
450463
}
451464

452465
template <typename T, typename Scalar>
453-
requires TiledArray::detail::is_numeric_v<Scalar>
466+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
454467
UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg,
455468
const Scalar factor) {
456469
add_to(result, arg);
@@ -462,6 +475,7 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg,
462475
///
463476

464477
template <typename T>
478+
requires TiledArray::detail::is_numeric_v<T>
465479
UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
466480
UMTensor<T> result(arg1.range());
467481

@@ -482,23 +496,23 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
482496
}
483497

484498
template <typename T, typename Scalar>
485-
requires TiledArray::detail::is_numeric_v<Scalar>
499+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
486500
UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
487501
const Scalar factor) {
488502
auto result = subt(arg1, arg2);
489503
return scale_to(result, factor);
490504
}
491505

492506
template <typename T, typename Perm>
493-
requires TiledArray::detail::is_permutation_v<Perm>
507+
requires TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T>
494508
UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
495509
const Perm &perm) {
496510
auto result = subt(arg1, arg2);
497511
return permute(result, perm);
498512
}
499513

500514
template <typename T, typename Scalar, typename Perm>
501-
requires TiledArray::detail::is_numeric_v<Scalar> &&
515+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> &&
502516
TiledArray::detail::is_permutation_v<Perm>
503517
UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
504518
const Scalar factor, const Perm &perm) {
@@ -511,6 +525,7 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
511525
///
512526

513527
template <typename T>
528+
requires TiledArray::detail::is_numeric_v<T>
514529
UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg) {
515530
auto &queue = blasqueue_for(result.range());
516531
const auto stream = device::Stream(queue.device(), queue.stream());
@@ -526,7 +541,7 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg) {
526541
}
527542

528543
template <typename T, typename Scalar>
529-
requires TiledArray::detail::is_numeric_v<Scalar>
544+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
530545
UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg,
531546
const Scalar factor) {
532547
subt_to(result, arg);
@@ -538,6 +553,7 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg,
538553
///
539554

540555
template <typename T>
556+
requires TiledArray::detail::is_numeric_v<T>
541557
UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
542558
TA_ASSERT(arg1.size() == arg2.size());
543559

@@ -557,23 +573,23 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
557573
}
558574

559575
template <typename T, typename Scalar>
560-
requires TiledArray::detail::is_numeric_v<Scalar>
576+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
561577
UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
562578
const Scalar factor) {
563579
auto result = mult(arg1, arg2);
564580
return scale_to(result, factor);
565581
}
566582

567583
template <typename T, typename Perm>
568-
requires TiledArray::detail::is_permutation_v<Perm>
584+
requires TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T>
569585
UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
570586
const Perm &perm) {
571587
auto result = mult(arg1, arg2);
572588
return permute(result, perm);
573589
}
574590

575591
template <typename T, typename Scalar, typename Perm>
576-
requires TiledArray::detail::is_numeric_v<Scalar> &&
592+
requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> &&
577593
TiledArray::detail::is_permutation_v<Perm>
578594
UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
579595
const Scalar factor, const Perm &perm) {
@@ -586,6 +602,7 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
586602
///
587603

588604
template <typename T>
605+
requires TiledArray::detail::is_numeric_v<T>
589606
UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg) {
590607
auto stream = device::stream_for(result.range());
591608
TA_ASSERT(result.size() == arg.size());
@@ -614,6 +631,7 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg,
614631
///
615632

616633
template <typename T>
634+
requires TiledArray::detail::is_numeric_v<T>
617635
T dot(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
618636
auto &queue = blasqueue_for(arg1.range());
619637
const auto stream = device::Stream(queue.device(), queue.stream());
@@ -634,6 +652,7 @@ T dot(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
634652
///
635653

636654
template <typename T>
655+
requires TiledArray::detail::is_numeric_v<T>
637656
T squared_norm(const UMTensor<T> &arg) {
638657
auto &queue = blasqueue_for(arg.range());
639658
const auto stream = device::Stream(queue.device(), queue.stream());
@@ -649,11 +668,13 @@ T squared_norm(const UMTensor<T> &arg) {
649668
}
650669

651670
template <typename T>
671+
requires TiledArray::detail::is_numeric_v<T>
652672
T norm(const UMTensor<T> &arg) {
653673
return std::sqrt(squared_norm(arg));
654674
}
655675

656676
template <typename T>
677+
requires TiledArray::detail::is_numeric_v<T>
657678
T sum(const UMTensor<T> &arg) {
658679
detail::to_device(arg);
659680
auto stream = device::stream_for(arg.range());
@@ -664,6 +685,7 @@ T sum(const UMTensor<T> &arg) {
664685
}
665686

666687
template <typename T>
688+
requires TiledArray::detail::is_numeric_v<T>
667689
T product(const UMTensor<T> &arg) {
668690
detail::to_device(arg);
669691
auto stream = device::stream_for(arg.range());
@@ -674,6 +696,7 @@ T product(const UMTensor<T> &arg) {
674696
}
675697

676698
template <typename T>
699+
requires TiledArray::detail::is_numeric_v<T>
677700
T max(const UMTensor<T> &arg) {
678701
detail::to_device(arg);
679702
auto stream = device::stream_for(arg.range());
@@ -684,6 +707,7 @@ T max(const UMTensor<T> &arg) {
684707
}
685708

686709
template <typename T>
710+
requires TiledArray::detail::is_numeric_v<T>
687711
T min(const UMTensor<T> &arg) {
688712
detail::to_device(arg);
689713
auto stream = device::stream_for(arg.range());
@@ -694,6 +718,7 @@ T min(const UMTensor<T> &arg) {
694718
}
695719

696720
template <typename T>
721+
requires TiledArray::detail::is_numeric_v<T>
697722
T abs_max(const UMTensor<T> &arg) {
698723
detail::to_device(arg);
699724
auto stream = device::stream_for(arg.range());
@@ -704,6 +729,7 @@ T abs_max(const UMTensor<T> &arg) {
704729
}
705730

706731
template <typename T>
732+
requires TiledArray::detail::is_numeric_v<T>
707733
T abs_min(const UMTensor<T> &arg) {
708734
detail::to_device(arg);
709735
auto stream = device::stream_for(arg.range());
@@ -721,6 +747,7 @@ namespace madness {
721747
namespace archive {
722748

723749
template <typename Archive, typename T>
750+
requires TiledArray::detail::is_numeric_v<T>
724751
struct ArchiveStoreImpl<Archive, TiledArray::UMTensor<T>> {
725752
static inline void store(const Archive &ar,
726753
const TiledArray::UMTensor<T> &t) {
@@ -736,6 +763,7 @@ struct ArchiveStoreImpl<Archive, TiledArray::UMTensor<T>> {
736763
};
737764

738765
template <typename Archive, typename T>
766+
requires TiledArray::detail::is_numeric_v<T>
739767
struct ArchiveLoadImpl<Archive, TiledArray::UMTensor<T>> {
740768
static inline void load(const Archive &ar, TiledArray::UMTensor<T> &t) {
741769
TiledArray::Range range{};

0 commit comments

Comments
 (0)