@@ -281,8 +281,10 @@ impl<A, S, D> ArrayBase<S, D>
281281 D : Dimension ,
282282{
283283 /// Perform the operation `self += alpha * rhs` efficiently, where
284- /// `alpha` is a scalar and `rhs` is another array. This operation is
285- /// also known as `axpy` in BLAS.
284+ /// `alpha` is a scalar and `rhs` is another array.
285+ ///
286+ /// This operation calls BLAS `axpy` if blas is enabled, the scalar type
287+ /// is `f32` or `f64` and the dimensions are compatible.
286288 ///
287289 /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
288290 ///
@@ -322,25 +324,19 @@ impl<A, S, D> ArrayBase<S, D>
322324 A : LinalgScalar ,
323325 E : Dimension ,
324326 {
325- debug_assert_eq ! ( self . len( ) , rhs. len( ) ) ;
326- assert ! ( self . len( ) == rhs. len( ) ) ;
327- {
327+ if ( same_type :: < A , f32 > ( ) || same_type :: < A , f64 > ( ) ) &&
328+ self . len ( ) == rhs. len ( ) {
328329 macro_rules! axpy {
329330 ( $ty: ty, $func: ident) => { {
330- if blas_compat:: <$ty, _, _>( self ) && blas_compat:: <$ty, _, _>( rhs) {
331- let order = Dimension :: _fastest_varying_stride_order( & self . strides) ;
332- let incx = self . strides( ) [ order[ 0 ] ] ;
333-
334- let order = Dimension :: _fastest_varying_stride_order( & rhs. strides) ;
335- let incy = self . strides( ) [ order[ 0 ] ] ;
336-
331+ if let Some ( self_stride) = blas_compat_axpy:: <$ty, _, _>( self ) {
332+ if let Some ( rhs_stride) = blas_compat_axpy:: <$ty, _, _>( rhs) {
337333 unsafe {
338334 let ( lhs_ptr, n, incx) = blas_1d_params( self . ptr,
339335 self . len( ) ,
340- incx ) ;
336+ self_stride ) ;
341337 let ( rhs_ptr, _, incy) = blas_1d_params( rhs. ptr,
342- rhs . len( ) ,
343- incy ) ;
338+ self . len( ) ,
339+ rhs_stride ) ;
344340 blas_sys:: c:: $func(
345341 n,
346342 cast_as( & alpha) ,
@@ -351,8 +347,10 @@ impl<A, S, D> ArrayBase<S, D>
351347 return ;
352348 }
353349 }
354- } }
350+ }
351+ } }
355352 }
353+
356354 axpy ! { f32 , cblas_saxpy} ;
357355 axpy ! { f64 , cblas_daxpy} ;
358356 }
@@ -595,29 +593,29 @@ fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
595593}
596594
/// Return the inter-element stride of `a` if its memory layout can be handed
/// to a BLAS level-1 routine such as `axpy`, or `None` otherwise.
///
/// `A` is the scalar type the BLAS routine operates on; the array qualifies
/// only when its element type is exactly `A`.
///
/// Two layouts are accepted:
///
/// - one-dimensional arrays, for which the array's own stride (possibly
///   negative) is returned, provided `len * |stride|` is below
///   `blas_index::max_value()`;
/// - arrays for which `D::is_contiguous` holds, for which a stride of 1 is
///   returned, provided `len` is below `blas_index::max_value()`.
// FIXME: Support more memory layouts
#[cfg(feature = "blas")]
fn blas_compat_axpy<A, S, D>(a: &ArrayBase<S, D>) -> Option<isize>
    where S: Data,
          A: 'static,
          S::Elem: 'static,
          D: Dimension,
{
    if !same_type::<A, S::Elem>() {
        return None;
    }

    let limit = blas_index::max_value() as usize;

    if a.ndim() == 1 {
        let stride = a.strides()[0];
        // Use checked arithmetic: a wrapped product could otherwise slip
        // under `limit` and let an out-of-range extent through to BLAS.
        // `checked_abs` yields a non-negative value, so the cast is lossless.
        let extent = stride
            .checked_abs()
            .and_then(|magnitude| a.len().checked_mul(magnitude as usize));
        match extent {
            Some(extent) if extent < limit => Some(stride),
            _ => None,
        }
    } else if D::is_contiguous(&a.dim, &a.strides) && a.len() < limit {
        Some(1)
    } else {
        None
    }
}
622620
623621#[ cfg( feature="blas" ) ]
0 commit comments