@@ -48,12 +48,14 @@ namespace detail {
4848
4949// / is_device_tile specialization for UMTensor
5050template  <typename  T>
51+ requires  TiledArray::detail::is_numeric_v<T>
5152struct  is_device_tile <
5253    ::TiledArray::Tensor<T, TiledArray::device_um_allocator<T>>>
5354    : public std::true_type {};
5455
5556// / pre-fetch to device
5657template  <typename  T>
58+ requires  TiledArray::detail::is_numeric_v<T>
5759void  to_device (const  UMTensor<T> &tensor) {
5860  auto  stream = device::stream_for (tensor.range ());
5961  TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(tensor,
@@ -62,6 +64,7 @@ void to_device(const UMTensor<T> &tensor) {
6264
6365// / pre-fetch to host
6466template  <typename  T>
67+ requires  TiledArray::detail::is_numeric_v<T>
6568void  to_host (const  UMTensor<T> &tensor) {
6669  auto  stream = device::stream_for (tensor.range ());
6770  TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(tensor,
@@ -83,6 +86,7 @@ void to_host(const UMTensor<T> &tensor) {
8386// / handle ComplexConjugate handling for scaling functions
8487// / follows the logic in device/btas.h
8588template  <typename  T, typename  Scalar, typename  Queue>
89+ requires  TiledArray::detail::is_numeric_v<T>
8690void  apply_scale_factor (T *data, std::size_t  size, const  Scalar &factor,
8791                        Queue &queue) {
8892  if  constexpr  (TiledArray::detail::is_blas_numeric_v<Scalar> ||
@@ -111,7 +115,7 @@ void apply_scale_factor(T *data, std::size_t size, const Scalar &factor,
111115// /
112116
113117template  <typename  T, typename  Scalar>
114-   requires  TiledArray::detail::is_numeric_v<Scalar>
118+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> 
115119UMTensor<T> gemm (const  UMTensor<T> &left, const  UMTensor<T> &right,
116120                 Scalar factor,
117121                 const  TiledArray::math::GemmHelper &gemm_helper) {
@@ -166,7 +170,7 @@ UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
166170}
167171
168172template  <typename  T, typename  Scalar>
169-   requires  TiledArray::detail::is_numeric_v<Scalar>
173+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> 
170174void  gemm (UMTensor<T> &result, const  UMTensor<T> &left,
171175          const  UMTensor<T> &right, Scalar factor,
172176          const  TiledArray::math::GemmHelper &gemm_helper) {
@@ -230,6 +234,7 @@ void gemm(UMTensor<T> &result, const UMTensor<T> &left,
230234// /
231235
232236template  <typename  T>
237+ requires  TiledArray::detail::is_numeric_v<T>
233238UMTensor<T> clone (const  UMTensor<T> &arg) {
234239  TA_ASSERT (!arg.empty ());
235240
@@ -252,6 +257,7 @@ UMTensor<T> clone(const UMTensor<T> &arg) {
252257// /
253258
254259template  <typename  T, typename  Index>
260+ requires  TiledArray::detail::is_numeric_v<T>
255261UMTensor<T> shift (const  UMTensor<T> &arg, const  Index &bound_shift) {
256262  TA_ASSERT (!arg.empty ());
257263
@@ -276,6 +282,7 @@ UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
276282}
277283
278284template  <typename  T, typename  Index>
285+ requires  TiledArray::detail::is_numeric_v<T>
279286UMTensor<T> &shift_to (UMTensor<T> &arg, const  Index &bound_shift) {
280287  const_cast <TiledArray::Range &>(arg.range ()).inplace_shift (bound_shift);
281288  return  arg;
@@ -286,6 +293,7 @@ UMTensor<T> &shift_to(UMTensor<T> &arg, const Index &bound_shift) {
286293// /
287294
288295template  <typename  T>
296+ requires  TiledArray::detail::is_numeric_v<T>
289297UMTensor<T> permute (const  UMTensor<T> &arg,
290298                    const  TiledArray::Permutation &perm) {
291299  TA_ASSERT (!arg.empty ());
@@ -308,6 +316,7 @@ UMTensor<T> permute(const UMTensor<T> &arg,
308316}
309317
310318template  <typename  T>
319+ requires  TiledArray::detail::is_numeric_v<T>
311320UMTensor<T> permute (const  UMTensor<T> &arg,
312321                    const  TiledArray::BipartitePermutation &perm) {
313322  TA_ASSERT (!arg.empty ());
@@ -320,7 +329,7 @@ UMTensor<T> permute(const UMTensor<T> &arg,
320329// /
321330
322331template  <typename  T, typename  Scalar>
323-   requires  TiledArray::detail::is_numeric_v<Scalar>
332+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> 
324333UMTensor<T> scale (const  UMTensor<T> &arg, const  Scalar factor) {
325334  auto  &queue = blasqueue_for (arg.range ());
326335  const  auto  stream = device::Stream (queue.device (), queue.stream ());
@@ -335,7 +344,7 @@ UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor) {
335344}
336345
337346template  <typename  T, typename  Scalar>
338-   requires  TiledArray::detail::is_numeric_v<Scalar>
347+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> 
339348UMTensor<T> &scale_to (UMTensor<T> &arg, const  Scalar factor) {
340349  auto  &queue = blasqueue_for (arg.range ());
341350  const  auto  stream = device::Stream (queue.device (), queue.stream ());
@@ -352,7 +361,7 @@ UMTensor<T> &scale_to(UMTensor<T> &arg, const Scalar factor) {
352361}
353362
354363template  <typename  T, typename  Scalar, typename  Perm>
355-   requires  TiledArray::detail::is_numeric_v<Scalar> &&
364+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> && 
356365           TiledArray::detail::is_permutation_v<Perm>
357366UMTensor<T> scale (const  UMTensor<T> &arg, const  Scalar factor,
358367                  const  Perm &perm) {
@@ -365,18 +374,20 @@ UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor,
365374// /
366375
367376template  <typename  T>
377+ requires  TiledArray::detail::is_numeric_v<T>
368378UMTensor<T> neg (const  UMTensor<T> &arg) {
369379  return  scale (arg, T (-1.0 ));
370380}
371381
372382template  <typename  T, typename  Perm>
373-   requires  TiledArray::detail::is_permutation_v<Perm>
383+   requires  TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T> 
374384UMTensor<T> neg (const  UMTensor<T> &arg, const  Perm &perm) {
375385  auto  result = neg (arg);
376386  return  permute (result, perm);
377387}
378388
379389template  <typename  T>
390+ requires  TiledArray::detail::is_numeric_v<T>
380391UMTensor<T> &neg_to (UMTensor<T> &arg) {
381392  return  scale_to (arg, T (-1.0 ));
382393}
@@ -386,6 +397,7 @@ UMTensor<T> &neg_to(UMTensor<T> &arg) {
386397// /
387398
388399template  <typename  T>
400+ requires  TiledArray::detail::is_numeric_v<T>
389401UMTensor<T> add (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2) {
390402  UMTensor<T> result (arg1.range ());
391403
@@ -406,23 +418,23 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
406418}
407419
408420template  <typename  T, typename  Scalar>
409-   requires  TiledArray::detail::is_numeric_v<Scalar>
421+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> 
410422UMTensor<T> add (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2,
411423                const  Scalar factor) {
412424  auto  result = add (arg1, arg2);
413425  return  scale_to (result, factor);
414426}
415427
416428template  <typename  T, typename  Perm>
417-   requires  TiledArray::detail::is_permutation_v<Perm>
429+   requires  TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T> 
418430UMTensor<T> add (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2,
419431                const  Perm &perm) {
420432  auto  result = add (arg1, arg2);
421433  return  permute (result, perm);
422434}
423435
424436template  <typename  T, typename  Scalar, typename  Perm>
425-   requires  TiledArray::detail::is_numeric_v<Scalar> &&
437+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> && 
426438           TiledArray::detail::is_permutation_v<Perm>
427439UMTensor<T> add (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2,
428440                const  Scalar factor, const  Perm &perm) {
@@ -435,6 +447,7 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
435447// /
436448
437449template  <typename  T>
450+ requires  TiledArray::detail::is_numeric_v<T>
438451UMTensor<T> &add_to (UMTensor<T> &result, const  UMTensor<T> &arg) {
439452  auto  &queue = blasqueue_for (result.range ());
440453  const  auto  stream = device::Stream (queue.device (), queue.stream ());
@@ -450,7 +463,7 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg) {
450463}
451464
452465template  <typename  T, typename  Scalar>
453-   requires  TiledArray::detail::is_numeric_v<Scalar>
466+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> 
454467UMTensor<T> &add_to (UMTensor<T> &result, const  UMTensor<T> &arg,
455468                    const  Scalar factor) {
456469  add_to (result, arg);
@@ -462,6 +475,7 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg,
462475// /
463476
464477template  <typename  T>
478+ requires  TiledArray::detail::is_numeric_v<T>
465479UMTensor<T> subt (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2) {
466480  UMTensor<T> result (arg1.range ());
467481
@@ -482,23 +496,23 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
482496}
483497
484498template  <typename  T, typename  Scalar>
485-   requires  TiledArray::detail::is_numeric_v<Scalar>
499+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> 
486500UMTensor<T> subt (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2,
487501                 const  Scalar factor) {
488502  auto  result = subt (arg1, arg2);
489503  return  scale_to (result, factor);
490504}
491505
492506template  <typename  T, typename  Perm>
493-   requires  TiledArray::detail::is_permutation_v<Perm>
507+   requires  TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T> 
494508UMTensor<T> subt (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2,
495509                 const  Perm &perm) {
496510  auto  result = subt (arg1, arg2);
497511  return  permute (result, perm);
498512}
499513
500514template  <typename  T, typename  Scalar, typename  Perm>
501-   requires  TiledArray::detail::is_numeric_v<Scalar> &&
515+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> && 
502516           TiledArray::detail::is_permutation_v<Perm>
503517UMTensor<T> subt (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2,
504518                 const  Scalar factor, const  Perm &perm) {
@@ -511,6 +525,7 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
511525// /
512526
513527template  <typename  T>
528+ requires  TiledArray::detail::is_numeric_v<T>
514529UMTensor<T> &subt_to (UMTensor<T> &result, const  UMTensor<T> &arg) {
515530  auto  &queue = blasqueue_for (result.range ());
516531  const  auto  stream = device::Stream (queue.device (), queue.stream ());
@@ -526,7 +541,7 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg) {
526541}
527542
528543template  <typename  T, typename  Scalar>
529-   requires  TiledArray::detail::is_numeric_v<Scalar>
544+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> 
530545UMTensor<T> &subt_to (UMTensor<T> &result, const  UMTensor<T> &arg,
531546                     const  Scalar factor) {
532547  subt_to (result, arg);
@@ -538,6 +553,7 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg,
538553// /
539554
540555template  <typename  T>
556+ requires  TiledArray::detail::is_numeric_v<T>
541557UMTensor<T> mult (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2) {
542558  TA_ASSERT (arg1.size () == arg2.size ());
543559
@@ -557,23 +573,23 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
557573}
558574
559575template  <typename  T, typename  Scalar>
560-   requires  TiledArray::detail::is_numeric_v<Scalar>
576+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> 
561577UMTensor<T> mult (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2,
562578                 const  Scalar factor) {
563579  auto  result = mult (arg1, arg2);
564580  return  scale_to (result, factor);
565581}
566582
567583template  <typename  T, typename  Perm>
568-   requires  TiledArray::detail::is_permutation_v<Perm>
584+   requires  TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T> 
569585UMTensor<T> mult (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2,
570586                 const  Perm &perm) {
571587  auto  result = mult (arg1, arg2);
572588  return  permute (result, perm);
573589}
574590
575591template  <typename  T, typename  Scalar, typename  Perm>
576-   requires  TiledArray::detail::is_numeric_v<Scalar> &&
592+   requires  TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> && 
577593           TiledArray::detail::is_permutation_v<Perm>
578594UMTensor<T> mult (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2,
579595                 const  Scalar factor, const  Perm &perm) {
@@ -586,6 +602,7 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
586602// /
587603
588604template  <typename  T>
605+ requires  TiledArray::detail::is_numeric_v<T>
589606UMTensor<T> &mult_to (UMTensor<T> &result, const  UMTensor<T> &arg) {
590607  auto  stream = device::stream_for (result.range ());
591608  TA_ASSERT (result.size () == arg.size ());
@@ -614,6 +631,7 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg,
614631// /
615632
616633template  <typename  T>
634+ requires  TiledArray::detail::is_numeric_v<T>
617635T dot (const  UMTensor<T> &arg1, const  UMTensor<T> &arg2) {
618636  auto  &queue = blasqueue_for (arg1.range ());
619637  const  auto  stream = device::Stream (queue.device (), queue.stream ());
@@ -634,6 +652,7 @@ T dot(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
634652// /
635653
636654template  <typename  T>
655+ requires  TiledArray::detail::is_numeric_v<T>
637656T squared_norm (const  UMTensor<T> &arg) {
638657  auto  &queue = blasqueue_for (arg.range ());
639658  const  auto  stream = device::Stream (queue.device (), queue.stream ());
@@ -649,11 +668,13 @@ T squared_norm(const UMTensor<T> &arg) {
649668}
650669
651670template  <typename  T>
671+ requires  TiledArray::detail::is_numeric_v<T>
652672T norm (const  UMTensor<T> &arg) {
653673  return  std::sqrt (squared_norm (arg));
654674}
655675
656676template  <typename  T>
677+ requires  TiledArray::detail::is_numeric_v<T>
657678T sum (const  UMTensor<T> &arg) {
658679  detail::to_device (arg);
659680  auto  stream = device::stream_for (arg.range ());
@@ -664,6 +685,7 @@ T sum(const UMTensor<T> &arg) {
664685}
665686
666687template  <typename  T>
688+ requires  TiledArray::detail::is_numeric_v<T>
667689T product (const  UMTensor<T> &arg) {
668690  detail::to_device (arg);
669691  auto  stream = device::stream_for (arg.range ());
@@ -674,6 +696,7 @@ T product(const UMTensor<T> &arg) {
674696}
675697
676698template  <typename  T>
699+ requires  TiledArray::detail::is_numeric_v<T>
677700T max (const  UMTensor<T> &arg) {
678701  detail::to_device (arg);
679702  auto  stream = device::stream_for (arg.range ());
@@ -684,6 +707,7 @@ T max(const UMTensor<T> &arg) {
684707}
685708
686709template  <typename  T>
710+ requires  TiledArray::detail::is_numeric_v<T>
687711T min (const  UMTensor<T> &arg) {
688712  detail::to_device (arg);
689713  auto  stream = device::stream_for (arg.range ());
@@ -694,6 +718,7 @@ T min(const UMTensor<T> &arg) {
694718}
695719
696720template  <typename  T>
721+ requires  TiledArray::detail::is_numeric_v<T>
697722T abs_max (const  UMTensor<T> &arg) {
698723  detail::to_device (arg);
699724  auto  stream = device::stream_for (arg.range ());
@@ -704,6 +729,7 @@ T abs_max(const UMTensor<T> &arg) {
704729}
705730
706731template  <typename  T>
732+ requires  TiledArray::detail::is_numeric_v<T>
707733T abs_min (const  UMTensor<T> &arg) {
708734  detail::to_device (arg);
709735  auto  stream = device::stream_for (arg.range ());
@@ -721,6 +747,7 @@ namespace madness {
721747namespace  archive  {
722748
723749template  <typename  Archive, typename  T>
750+ requires  TiledArray::detail::is_numeric_v<T>
724751struct  ArchiveStoreImpl <Archive, TiledArray::UMTensor<T>> {
725752  static  inline  void  store (const  Archive &ar,
726753                           const  TiledArray::UMTensor<T> &t) {
@@ -736,6 +763,7 @@ struct ArchiveStoreImpl<Archive, TiledArray::UMTensor<T>> {
736763};
737764
738765template  <typename  Archive, typename  T>
766+ requires  TiledArray::detail::is_numeric_v<T>
739767struct  ArchiveLoadImpl <Archive, TiledArray::UMTensor<T>> {
740768  static  inline  void  load (const  Archive &ar, TiledArray::UMTensor<T> &t) {
741769    TiledArray::Range range{};
0 commit comments